diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a0f62ebd..fd85de421 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) +* Add efSearch as a query parameter [#1707](https://github.com/opensearch-project/k-NN/pull/1707) * Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) * Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713) ### Bug Fixes diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 595fa6fea..32c6e8634 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -147,6 +147,7 @@ if ("${WIN32}" STREQUAL "") add_executable( jni_test tests/faiss_wrapper_test.cpp + tests/faiss_wrapper_unit_test.cpp tests/faiss_util_test.cpp tests/nmslib_wrapper_test.cpp tests/test_util.cpp diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 958eca8ac..aa747862a 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -45,17 +45,27 @@ namespace knn_jni { // Sets the sharedIndexState for an index void SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ); - // Execute a query against the index located in memory at indexPointerJ. 
- // - // Return an array of KNNQueryResults + /** + * Execute a query against the index located in memory at indexPointerJ + * + * Parameters: + * methodParamsJ: introduces a map to have additional method parameters + * + * Return an array of KNNQueryResults + */ jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ); - - // Execute a query against the index located in memory at indexPointerJ along with Filters - // - // Return an array of KNNQueryResults + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ); + + /** + * Execute a query against the index located in memory at indexPointerJ along with Filters + * + * Parameters: + * methodParamsJ: introduces a map to have additional method parameters + * + * Return an array of KNNQueryResults + */ jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index e16677db7..0453864e4 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -69,18 +69,18 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndex - * Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FILjava/util/Map;[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex - (JNIEnv *, jclass, jlong, 
jfloatArray, jint, jobject, jintArray); /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndexWithFilter - * Signature: (J[FI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FILjava/util/Map;[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 5a0910d9a..bd556ad8d 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -72,6 +72,9 @@ void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, fa // Concerts the FilterIds to BitMap void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); +// Gets efSearch from algo parameters +int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, int defaultEfSearch); + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap); // Check if a loaded index is an IVFPQ index with l2 space type @@ -296,12 +299,12 @@ void knn_jni::faiss_wrapper::SetSharedIndexState(jlong indexPointerJ, jlong shar } jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ); + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, nullptr, 0, parentIdsJ); } 
jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); @@ -313,6 +316,11 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter throw std::runtime_error("Invalid pointer to index"); } + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from // the query point std::vector dis(kJ); @@ -340,9 +348,8 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); if(hnswReader) { - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch); hnswParams.sel = idSelector.get(); if (parentIdsJ != nullptr) { idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); @@ -371,12 +378,13 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader!= nullptr && parentIdsJ != nullptr) { - // Setting the ef_search value equal to what was provided during index creation. 
SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; - idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); - hnswParams.grp = idGrouper.get(); + if(hnswReader!= nullptr) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } searchParameters = &hnswParams; } try { @@ -409,6 +417,18 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter return results; } +int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, int defaultEfSearch) { + if (methodParams.empty()) { + return defaultEfSearch; + } + auto efSearchIt = methodParams.find(knn_jni::EF_SEARCH); + if (efSearchIt != methodParams.end()) { + return jniUtil->ConvertJavaObjectToCppInteger(env, methodParams[knn_jni::EF_SEARCH]); + } + + return defaultEfSearch; +} + void knn_jni::faiss_wrapper::Free(jlong indexPointer) { auto *indexWrapper = reinterpret_cast(indexPointer); delete indexWrapper; diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 0aa51987d..57353f9e1 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -109,10 +109,10 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ) { try { - return 
knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); @@ -121,10 +121,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 4cd3b319e..09b30c49a 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -245,6 +245,10 @@ TEST(FaissQueryIndexTest, BasicAssertions) { // Define query data int k = 10; + int efSearch = 20; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + int numQueries = 100; std::vector> queries; @@ -266,6 +270,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + auto methodParamsJ = reinterpret_cast(&methodParams); for (auto query : queries) { std::unique_ptr *>> results( @@ -273,7 +278,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, nullptr))); + reinterpret_cast(&query), k, methodParamsJ, nullptr))); ASSERT_EQ(k, results->size()); @@ -282,6 +287,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { delete it; } } + std::cout << "Test end"; } //Test for a bug reported in https://github.com/opensearch-project/k-NN/issues/1435 @@ -339,7 +345,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { knn_jni::faiss_wrapper::QueryIndex_WithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, + reinterpret_cast(&query), k, nullptr, reinterpret_cast(&bitmap), 0, nullptr))); ASSERT_TRUE(results->size() <= filterIds.size()); @@ -397,20 +403,20 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { auto createdIndexWithData = test_util::FaissAddData(createdIndex.get(), ids, vectors); + int efSearch = 100; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength( - 
jniEnv, reinterpret_cast(&parentIds))) - .WillRepeatedly(Return(parentIds.size())); for (auto query : queries) { std::unique_ptr *>> results( reinterpret_cast *> *>( knn_jni::faiss_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, + reinterpret_cast(&query), k, reinterpret_cast(&methodParams), reinterpret_cast(&parentIds)))); // Even with k 20, result should have only 10 which is total number of groups diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp new file mode 100644 index 000000000..ea9131dd7 --- /dev/null +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -0,0 +1,251 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#include "faiss_wrapper.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "jni.h" +#include "test_util.h" +#include "faiss/IndexHNSW.h" +#include "faiss/IndexIDMap.h" + +using ::testing::NiceMock; + +using idx_t = faiss::idx_t; + +struct MockIndex : faiss::IndexHNSW { + explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { + } +}; + + +struct MockIdMap : faiss::IndexIDMap { + mutable idx_t nCalled; + mutable const float *xCalled; + mutable idx_t kCalled; + mutable float *distancesCalled; + mutable idx_t *labelsCalled; + mutable const faiss::SearchParametersHNSW *paramsCalled; + + explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate(index) { + } + + void search( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + const faiss::SearchParameters *params) const override { + nCalled = n; + xCalled = x; + kCalled = k; + distancesCalled = distances; + labelsCalled = labels; + paramsCalled = 
dynamic_cast(params); + } + + void resetMock() const { + nCalled = 0; + xCalled = nullptr; + kCalled = 0; + distancesCalled = nullptr; + labelsCalled = nullptr; + paramsCalled = nullptr; + } +}; + +struct QueryIndexHNSWTestInput { + string description; + int k; + int efSearch; + int filterIdType; + bool filterIdsPresent; + bool parentIdsPresent; +}; + + + +class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam { +public: + FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) { + index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere + }; + +protected: + MockIndex index_; + MockIdMap id_map_; +}; + +namespace query_index_test { + + std::unordered_map methodParams; + + + TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + + QueryIndexHNSWTestInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + + int efSearch = input.efSearch; + int expectedEfSearch = 100; //default set in mock + std::unordered_map methodParams; + if (efSearch != -1) { + expectedEfSearch = input.efSearch; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + + std::vector *parentIdPtr = nullptr; + if (input.parentIdsPresent) { + std::vector parentId; + parentId.reserve(2); + parentId.push_back(1); + parentId.push_back(2); + parentIdPtr = &parentId; + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(parentIdPtr))) + .WillOnce(testing::Return(parentId.size())); + + EXPECT_CALL(mockJNIUtil, + GetIntArrayElements( + jniEnv, reinterpret_cast(parentIdPtr), nullptr)) + .WillOnce(testing::Return(new int[2]{1, 2})); + } + + // When + knn_jni::faiss_wrapper::QueryIndex( + &mockJNIUtil, jniEnv, + reinterpret_cast(&id_map_), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), + reinterpret_cast(parentIdPtr)); + + //Then + int actualEfSearch = id_map_.paramsCalled->efSearch; + 
// Asserting the captured argument + EXPECT_EQ(input.k, id_map_.kCalled); + EXPECT_EQ(expectedEfSearch, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + + id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexHNSWTests, + FaissWrappeterParametrizedTestFixture, + ::testing::Values( + QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true} + ) + ); +} + +namespace query_index_with_filter_test { + + TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + + QueryIndexHNSWTestInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + + std::vector *parentIdPtr = nullptr; + if (input.parentIdsPresent) { + std::vector parentId; + parentId.reserve(2); + parentId.push_back(1); + parentId.push_back(2); + parentIdPtr = &parentId; + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(parentIdPtr))) + .WillOnce(testing::Return(parentId.size())); + + EXPECT_CALL(mockJNIUtil, + GetIntArrayElements( + jniEnv, reinterpret_cast(parentIdPtr), nullptr)) + .WillOnce(testing::Return(new int[2]{1, 2})); + } + + std::vector *filterptr = nullptr; + if (input.filterIdsPresent) { + std::vector filter; + filter.reserve(2); + filter.push_back(1); + filter.push_back(2); + filterptr = &filter; + } + + int efSearch = input.efSearch; + int expectedEfSearch = 100; //default set in mock + std::unordered_map methodParams; + if (efSearch != -1) { + expectedEfSearch = input.efSearch; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + + // When + 
knn_jni::faiss_wrapper::QueryIndex_WithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&id_map_), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), + reinterpret_cast(filterptr), + input.filterIdType, + reinterpret_cast(parentIdPtr)); + + //Then + int actualEfSearch = id_map_.paramsCalled->efSearch; + // Asserting the captured argument + EXPECT_EQ(input.k, id_map_.kCalled); + EXPECT_EQ(expectedEfSearch, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + if (input.filterIdsPresent) { + faiss::IDSelector *sel = id_map_.paramsCalled->sel; + EXPECT_TRUE(sel != nullptr); + } + id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexWithFilterHNSWTests, + FaissWrappeterParametrizedTestFixture, + ::testing::Values( + QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false}, + QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true}, + QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true} + ) + ); +} diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java new file mode 100644 index 000000000..7a0ed3d56 --- /dev/null +++ 
b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.bwc; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; + +public class QueryANNIT extends AbstractRestartUpgradeTestCase { + + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static final int K = 5; + private static final Integer EF_SEARCH = 10; + private static final int NUM_DOCS = 10; + private static final String ALGORITHM = "hnsw"; + + public void testQueryANN() throws Exception { + if (isRunningAgainstOldCluster()) { + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS, ALGORITHM, FAISS_NAME)); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + } else { + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K, Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH)); + deleteKNNIndex(testIndex); + } + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java new file mode 100644 index 000000000..080e63241 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the 
Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.bwc; + +import java.util.Map; + +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; + +public class QueryANNIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static final int K = 5; + private static final Integer EF_SEARCH = 10; + private static final int NUM_DOCS = 10; + private static final String ALGORITHM = "hnsw"; + + public void testQueryANNIT() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + switch (getClusterType()) { + case OLD: + createKnnIndex( + testIndex, + getKNNDefaultIndexSettings(), + createKnnIndexMapping(TEST_FIELD, DIMENSIONS, ALGORITHM, FAISS_NAME) + ); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + break; + case MIXED: + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + break; + case UPGRADED: + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K, Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH)); + deleteKNNIndex(testIndex); + } + } +} diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index d284015af..dd7b04f45 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -25,6 +25,7 @@ public class KNNConstants { public static final String VECTOR = "vector"; public static final String K = "k"; public static final String TYPE_KNN_VECTOR = 
"knn_vector"; + public static final String METHOD_PARAMETER = "method_parameters"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/engine/method/DefaultHnswContext.java b/src/main/java/org/opensearch/knn/engine/method/DefaultHnswContext.java new file mode 100644 index 000000000..bde219e8f --- /dev/null +++ b/src/main/java/org/opensearch/knn/engine/method/DefaultHnswContext.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.engine.method; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.knn.index.Parameter; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.util.Map; + +/** + * Default HNSW context for all engines. Have a different implementation if engine context differs. 
+ */ +public final class DefaultHnswContext implements EngineSpecificMethodContext { + + private final Map> supportedMethodParameters = ImmutableMap.>builder() + .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) + .build(); + + @Override + public Map> supportedMethodParameters() { + return supportedMethodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/engine/method/EngineSpecificMethodContext.java b/src/main/java/org/opensearch/knn/engine/method/EngineSpecificMethodContext.java new file mode 100644 index 000000000..68d8eacab --- /dev/null +++ b/src/main/java/org/opensearch/knn/engine/method/EngineSpecificMethodContext.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.engine.method; + +import org.opensearch.knn.index.Parameter; + +import java.util.Collections; +import java.util.Map; + +/** + * Holds context related to a method for a particular engine + * Each engine can have a specific set of parameters that it supports during index and build time. This context holds + * the information for each engine method combination. 
+ * + * TODO: Move KnnMethod in here + */ +public interface EngineSpecificMethodContext { + + Map> supportedMethodParameters(); + + EngineSpecificMethodContext EMPTY = Collections::emptyMap; +} diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 208da2ea9..4e89a46a3 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -20,6 +20,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.request.MethodParameter; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -45,14 +46,25 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; - private static final Map minimalRequiredVersionMap = new HashMap() { - { - put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); - put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); - put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); - put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + private static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); + + private static Map initializeMinimalRequiredVersionMap() { + final Map versionMap = new HashMap<>() { + { + put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); + put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); + 
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); + put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + } + }; + + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (methodParameter.getVersion() != null) { + versionMap.put(methodParameter.getName(), methodParameter.getVersion()); + } } - }; + return Collections.unmodifiableMap(versionMap); + } /** * Determines the size of a file on disk in kilobytes diff --git a/src/main/java/org/opensearch/knn/index/MethodComponent.java b/src/main/java/org/opensearch/knn/index/MethodComponent.java index 256d55ee5..cc545f66b 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponent.java @@ -19,12 +19,12 @@ import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.training.VectorSpaceInfo; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; +import static org.opensearch.knn.validation.ParameterValidator.validateParameters; + /** * MethodComponent defines the structure of an individual component that can make up an index */ @@ -75,32 +75,7 @@ public Map getAsMap(MethodComponentContext methodComponentContex */ public ValidationException validate(MethodComponentContext methodComponentContext) { Map providedParameters = methodComponentContext.getParameters(); - List errorMessages = new ArrayList<>(); - - if (providedParameters == null) { - return null; - } - - ValidationException parameterValidation; - for (Map.Entry parameter : providedParameters.entrySet()) { - if (!parameters.containsKey(parameter.getKey())) { - errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName())); - continue; - } - - parameterValidation = parameters.get(parameter.getKey()).validate(parameter.getValue()); - if (parameterValidation != 
null) { - errorMessages.addAll(parameterValidation.validationErrors()); - } - } - - if (errorMessages.isEmpty()) { - return null; - } - - ValidationException validationException = new ValidationException(); - validationException.addValidationErrors(errorMessages); - return validationException; + return validateParameters(parameters, providedParameters); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 3146cd33e..a02c090b1 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Map; import java.util.Optional; /** @@ -42,6 +43,7 @@ public static class CreateQueryRequest { private float[] vector; private byte[] byteVector; private VectorDataType vectorDataType; + private Map methodParameters; private Integer k; private Float radius; private QueryBuilder filter; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 0862b2d93..d123cc149 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -5,10 +5,8 @@ package org.opensearch.knn.index.query; -import java.util.Arrays; -import java.util.Objects; - import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -23,26 +21,29 @@ import org.opensearch.knn.index.KNNSettings; import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; /** * Custom KNN query. Query is used for KNNEngine's that create their own custom segment files. 
These files need to be * loaded and queried in a custom manner throughout the query path. */ +@Getter +@Builder +@AllArgsConstructor public class KNNQuery extends Query { private final String field; private final float[] queryVector; private int k; + private Map methodParameters; private final String indexName; - @Getter @Setter private Query filterQuery; - @Getter private BitSetProducer parentsFilter; - @Getter - private Float radius = null; - @Getter + private Float radius; private Context context; public KNNQuery( @@ -123,22 +124,6 @@ public KNNQuery filterQuery(Query filterQuery) { return this; } - public String getField() { - return this.field; - } - - public float[] getQueryVector() { - return this.queryVector; - } - - public int getK() { - return this.k; - } - - public String getIndexName() { - return this.indexName; - } - /** * Constructs Weight implementation for this query * @@ -183,7 +168,17 @@ public String toString(String field) { @Override public int hashCode() { - return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery); + return Objects.hash( + field, + Arrays.hashCode(queryVector), + k, + indexName, + filterQuery, + context, + parentsFilter, + radius, + methodParameters + ); } @Override @@ -192,10 +187,15 @@ public boolean equals(Object other) { } private boolean equalsTo(KNNQuery other) { + if (other == this) return true; return Objects.equals(field, other.field) && Arrays.equals(queryVector, other.queryVector) && Objects.equals(k, other.k) + && Objects.equals(methodParameters, other.methodParameters) + && Objects.equals(radius, other.radius) + && Objects.equals(context, other.context) && Objects.equals(indexName, other.indexName) + && Objects.equals(parentsFilter, other.parentsFilter) && Objects.equals(filterQuery, other.filterQuery); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index cef65307a..efc73496a 100644 --- 
a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,47 +5,61 @@ package org.opensearch.knn.index.query; -import java.io.IOException; -import java.util.Arrays; - -import java.util.List; -import java.util.Objects; - +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.common.ValidationException; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.parser.MethodParametersParser; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; -import org.apache.lucene.search.Query; -import 
org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.query.AbstractQueryBuilder; -import org.opensearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; +import static org.opensearch.knn.validation.ParameterValidator.validateParameters; /** * Helper class to build the KNN query */ +// The builder validates the member variables so access to the constructor is prohibited to not accidentally bypass validations +@AllArgsConstructor(access = AccessLevel.PRIVATE) @Log4j2 public class KNNQueryBuilder extends AbstractQueryBuilder { private static ModelDao modelDao; @@ -56,6 +70,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE); public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE); + public 
static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); + public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); public static final int K_MAX = 10000; /** * The name for the knn query @@ -66,18 +82,27 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { */ private final String fieldName; private final float[] vector; - private int k = 0; - private Float maxDistance = null; - private Float minScore = null; + @Getter + private int k; + @Getter + private Float maxDistance; + @Getter + private Float minScore; + @Getter + private Map methodParameters; + @Getter private QueryBuilder filter; - private boolean ignoreUnmapped = false; + @Getter + private boolean ignoreUnmapped; /** * Constructs a new query with the given field name and vector * * @param fieldName Name of the field * @param vector Array of floating points + * @deprecated Use {@code {@link KNNQueryBuilder.Builder}} instead */ + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); @@ -93,61 +118,135 @@ public KNNQueryBuilder(String fieldName, float[] vector) { } /** - * Builder method for k - * - * @param k K nearest neighbours for the given vector + * lombok SuperBuilder annotation requires a builder annotation on parent class to work well + * {@link AbstractQueryBuilder#boost()} and {@link AbstractQueryBuilder#queryName()} both need to be called + * A custom builder helps with the calls to the parent class, simultaneously addressing the problem of telescoping + * constructors in this class. 
*/ - public KNNQueryBuilder k(Integer k) { - if (k == null) { - throw new IllegalArgumentException(String.format("[%s] requires k to be set", NAME)); + public static class Builder { + private String fieldName; + private float[] vector; + private Integer k; + private Map methodParameters; + private Float maxDistance; + private Float minScore; + private QueryBuilder filter; + private boolean ignoreUnmapped; + private String queryName; + private float boost = DEFAULT_BOOST; + + private Builder() {} + + public Builder fieldName(String fieldName) { + this.fieldName = fieldName; + return this; } - validateSingleQueryType(k, maxDistance, minScore); - if (k <= 0 || k > K_MAX) { - throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); + + public Builder vector(float[] vector) { + this.vector = vector; + return this; } - this.k = k; - return this; - } - /** - * Builder method for maxDistance - * - * @param maxDistance the maxDistance threshold for the nearest neighbours - */ - public KNNQueryBuilder maxDistance(Float maxDistance) { - if (maxDistance == null) { - throw new IllegalArgumentException(String.format("[%s] requires maxDistance to be set", NAME)); + public Builder k(Integer k) { + this.k = k; + return this; } - validateSingleQueryType(k, maxDistance, minScore); - this.maxDistance = maxDistance; - return this; - } - /** - * Builder method for minScore - * - * @param minScore the minScore threshold for the nearest neighbours - */ - public KNNQueryBuilder minScore(Float minScore) { - if (minScore == null) { - throw new IllegalArgumentException(String.format("[%s] requires minScore to be set", NAME)); + public Builder methodParameters(Map methodParameters) { + this.methodParameters = methodParameters; + return this; + } + + public Builder maxDistance(Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + + public Builder minScore(Float minScore) { + this.minScore = minScore; + return this; + } + 
+ public Builder ignoreUnmapped(boolean ignoreUnmapped) { + this.ignoreUnmapped = ignoreUnmapped; + return this; + } + + public Builder filter(QueryBuilder filter) { + this.filter = filter; + return this; } - validateSingleQueryType(k, maxDistance, minScore); - if (minScore <= 0) { - throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); + + public Builder queryName(String queryName) { + this.queryName = queryName; + return this; + } + + public Builder boost(float boost) { + this.boost = boost; + return this; + } + + public KNNQueryBuilder build() { + validate(); + int k = this.k == null ? 0 : this.k; + return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost) + .queryName(queryName); + } + + private void validate() { + if (Strings.isNullOrEmpty(fieldName)) { + throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); + } + + if (vector == null) { + throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); + } else if (vector.length == 0) { + throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); + } + + if (k == null && minScore == null && maxDistance == null) { + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + } + + if ((k != null && maxDistance != null) || (maxDistance != null && minScore != null) || (k != null && minScore != null)) { + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + } + + VectorQueryType vectorQueryType = VectorQueryType.MAX_DISTANCE; + if (k != null) { + vectorQueryType = VectorQueryType.K; + if (k <= 0 || k > K_MAX) { + throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); + } + } + + if (minScore != null) { + vectorQueryType = 
VectorQueryType.MIN_SCORE; + if (minScore <= 0) { + throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); + } + } + + if (methodParameters != null) { + ValidationException validationException = validateMethodParameters(methodParameters); + if (validationException != null) { + throw new IllegalArgumentException( + String.format("[%s] errors in method parameter [%s]", NAME, validationException.getMessage()) + ); + } + } + + // Update stats + vectorQueryType.getQueryStatCounter().increment(); + if (filter != null) { + vectorQueryType.getQueryWithFilterStatCounter().increment(); + } } - this.minScore = minScore; - return this; } - /** - * Builder method for filter - * - * @param filter QueryBuilder - */ - public KNNQueryBuilder filter(QueryBuilder filter) { - this.filter = filter; - return this; + public static KNNQueryBuilder.Builder builder() { + return new KNNQueryBuilder.Builder(); } /** @@ -157,10 +256,12 @@ public KNNQueryBuilder filter(QueryBuilder filter) { * @param vector Array of floating points * @param k K nearest neighbours for the given vector */ + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector, int k) { this(fieldName, vector, k, null); } + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); @@ -225,6 +326,7 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { minScore = in.readOptionalFloat(); } + methodParameters = MethodParametersParser.streamInput(in); } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -241,6 +343,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String queryName = null; String currentFieldName = null; boolean 
ignoreUnmapped = false; + Map methodParameters = null; XContentParser.Token token; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -279,6 +382,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (FILTER_FIELD.getPreferredName().equals(tokenName)) { log.debug(String.format("Start parsing filter for field [%s]", fieldName)); filter = parseInnerQueryBuilder(parser); + } else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + methodParameters = MethodParametersParser.fromXContent(parser); } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } @@ -296,26 +401,18 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore); - vectorQueryType.getQueryStatCounter().increment(); - if (filter != null) { - vectorQueryType.getQueryWithFilterStatCounter().increment(); - } - - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) - .ignoreUnmapped(ignoreUnmapped) + return KNNQueryBuilder.builder() + .queryName(queryName) .boost(boost) - .queryName(queryName); - - if (k != null) { - knnQueryBuilder.k(k); - } else if (maxDistance != null) { - knnQueryBuilder.maxDistance(maxDistance); - } else if (minScore != null) { - knnQueryBuilder.minScore(minScore); - } - - return knnQueryBuilder; + .fieldName(fieldName) + .vector(ObjectsToFloats(vector)) + .k(k) + .maxDistance(maxDistance) + .minScore(minScore) + .methodParameters(methodParameters) + .ignoreUnmapped(ignoreUnmapped) + .filter(filter) + .build(); } @Override @@ -333,6 +430,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { out.writeOptionalFloat(minScore); } + 
MethodParametersParser.streamOutput(out, methodParameters); } /** @@ -349,36 +447,6 @@ public Object vector() { return this.vector; } - public int getK() { - return this.k; - } - - public float getMaxDistance() { - return this.maxDistance; - } - - public float getMinScore() { - return this.minScore; - } - - public QueryBuilder getFilter() { - return this.filter; - } - - /** - * Sets whether the query builder should ignore unmapped paths (and run a - * {@link MatchNoDocsQuery} in place of this query) or throw an exception if - * the path is unmapped. - */ - public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) { - this.ignoreUnmapped = ignoreUnmapped; - return this; - } - - public boolean getIgnoreUnmapped() { - return this.ignoreUnmapped; - } - @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -398,6 +466,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (minScore != null) { builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } + if (methodParameters != null) { + MethodParametersParser.doXContent(builder, methodParameters); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -418,6 +489,7 @@ protected Query doToQuery(QueryShardContext context) { KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType; int fieldDimension = knnVectorFieldType.getDimension(); KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); + MethodComponentContext methodComponentContext = null; KNNEngine knnEngine = KNNEngine.DEFAULT; VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); SpaceType spaceType = knnVectorFieldType.getSpaceType(); @@ -431,10 +503,32 @@ protected Query doToQuery(QueryShardContext context) { fieldDimension = modelMetadata.getDimension(); knnEngine = modelMetadata.getKnnEngine(); spaceType = 
modelMetadata.getSpaceType(); + methodComponentContext = modelMetadata.getMethodComponentContext(); + } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping knnEngine = knnMethodContext.getKnnEngine(); spaceType = knnMethodContext.getSpaceType(); + methodComponentContext = knnMethodContext.getMethodComponentContext(); + } + + final String method = methodComponentContext != null ? methodComponentContext.getName() : null; + if (StringUtils.isNotBlank(method)) { + final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method); + ValidationException validationException = validateParameters( + engineSpecificMethodContext.supportedMethodParameters(), + (Map) methodParameters + ); + if (validationException != null) { + throw new IllegalArgumentException( + String.format( + "Parameters not valid for [%s]:[%s] combination: [%s]", + knnEngine, + method, + validationException.getMessage() + ) + ); + } } // Currently, k-NN supports distance and score types radial search @@ -493,6 +587,7 @@ protected Query doToQuery(QueryShardContext context) { .byteVector(VectorDataType.BYTE == vectorDataType ? 
byteVector : null) .vectorDataType(vectorDataType) .k(this.k) + .methodParameters(this.methodParameters) .filter(this.filter) .context(context) .build(); @@ -537,41 +632,20 @@ protected boolean doEquals(KNNQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Arrays.equals(vector, other.vector) && Objects.equals(k, other.k) + && Objects.equals(minScore, other.minScore) + && Objects.equals(maxDistance, other.maxDistance) + && Objects.equals(methodParameters, other.methodParameters) && Objects.equals(filter, other.filter) && Objects.equals(ignoreUnmapped, other.ignoreUnmapped); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Arrays.hashCode(vector), k, filter, ignoreUnmapped); + return Objects.hash(fieldName, Arrays.hashCode(vector), k, methodParameters, filter, ignoreUnmapped, maxDistance, minScore); } @Override public String getWriteableName() { return NAME; } - - private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) { - int countSetFields = 0; - VectorQueryType vectorQueryType = null; - - if (k != null && k != 0) { - countSetFields++; - vectorQueryType = VectorQueryType.K; - } - if (distance != null) { - countSetFields++; - vectorQueryType = VectorQueryType.MAX_DISTANCE; - } - if (score != null) { - countSetFields++; - vectorQueryType = VectorQueryType.MIN_SCORE; - } - - if (countSetFields != 1) { - throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); - } - - return vectorQueryType; - } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index ec1f53d13..a25f941d7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.util.KNNEngine; import java.util.Locale; +import 
java.util.Map; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -71,6 +72,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { final byte[] byteVector = createQueryRequest.getByteVector(); final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + final Map methodParameters = createQueryRequest.getMethodParameters(); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { @@ -79,12 +81,24 @@ public static Query create(CreateQueryRequest createQueryRequest) { } if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { - if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { - log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); - return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter); - } - log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName, parentFilter); + final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); + log.debug( + "Creating custom k-NN query for index:{}, field:{}, k:{}, filterQuery:{}, efSearch:{}", + indexName, + fieldName, + k, + validatedFilterQuery, + methodParameters + ); + return KNNQuery.builder() + .field(fieldName) + .queryVector(vector) + .indexName(indexName) + .parentsFilter(parentFilter) + .k(k) + .methodParameters(methodParameters) + .filterQuery(validatedFilterQuery) + .build(); } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); @@ -106,6 +120,14 @@ public 
static Query create(CreateQueryRequest createQueryRequest) { } } + private static Query validateFilterQuerySupport(final Query filterQuery, final KNNEngine knnEngine) { + log.debug("filter query {}, knnEngine {}", filterQuery, knnEngine); + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { + return filterQuery; + } + return null; + } + /** * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery} * which will dedupe search result per parent so that we can get k parent results at the end. diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index bac8c03d4..794c9af1c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -283,6 +283,7 @@ private Map doANNSearch(final LeafReaderContext context, final B indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), + knnQuery.getMethodParameters(), knnEngine, filterIds, filterType.getValue(), diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNXParserUtil.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNXParserUtil.java new file mode 100644 index 000000000..9e83d76d7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNXParserUtil.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query.parser; + +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.core.xcontent.XContentParserUtils.parseFieldsValue; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; + +public final class KNNXParserUtil { + + public static Map parseJsonObject(XContentParser parser) throws IOException { + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] Error parsing json. current token should be START_OBJECT" + ); + } + + String fieldName = null; + XContentParser.Token token; + final Map fieldToValueMap = new HashMap<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = parser.currentName(); + } else { + assert fieldName != null; + fieldToValueMap.put(fieldName, parseFieldsValue(parser)); + } + } + return fieldToValueMap; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java new file mode 100644 index 000000000..e3f22c3ed --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java @@ -0,0 +1,133 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query.parser; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.common.ValidationException; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; +import static org.opensearch.knn.index.query.parser.KNNXParserUtil.parseJsonObject; + +@EqualsAndHashCode +@Getter +@AllArgsConstructor +public class MethodParametersParser { + + // Validation on rest layer + public static ValidationException validateMethodParameters(final Map methodParameters) { + final List errors = new ArrayList<>(); + for (final Map.Entry methodParameter : methodParameters.entrySet()) { + final MethodParameter parameter = MethodParameter.enumOf(methodParameter.getKey()); + if (parameter != null) { + final ValidationException validationException = parameter.validate(methodParameter.getValue()); + if (validationException != null) { + errors.add(validationException.getMessage()); + } + } else { // Should never happen if used in the right sequence + errors.add(methodParameter.getKey() + " is not a valid method parameter"); + } + } + + if (!errors.isEmpty()) { + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errors); + return validationException; + } + return null; + } + + // deserialize 
for node to node communication + public static Map streamInput(StreamInput in) throws IOException { + if (!in.readBoolean()) { + return null; + } + + final Map methodParameters = new HashMap<>(); + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (isClusterOnOrAfterMinRequiredVersion(methodParameter.getName())) { + String name = in.readString(); + Object value = in.readGenericValue(); + if (value != null) { + methodParameters.put(name, methodParameter.parse(value)); + } + } + } + + return !methodParameters.isEmpty() ? methodParameters : null; + } + + // serialize for node to node communication + public static void streamOutput(StreamOutput out, Map methodParameters) throws IOException { + if (methodParameters == null || methodParameters.isEmpty()) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + // All values are written to deserialize without ambiguity + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (isClusterOnOrAfterMinRequiredVersion(methodParameter.getName())) { + out.writeString(methodParameter.getName()); + out.writeGenericValue(methodParameters.get(methodParameter.getName())); + } + } + } + } + + public static void doXContent(final XContentBuilder builder, final Map methodParameters) throws IOException { + if (methodParameters == null || methodParameters.isEmpty()) { + return; + } + builder.startObject(METHOD_PARAMS_FIELD.getPreferredName()); + for (final Map.Entry entry : methodParameters.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + builder.field(entry.getKey(), entry.getValue()); + } + } + builder.endObject(); + } + + public static Map fromXContent(final XContentParser parser) throws IOException { + final Map methodParametersJson = parseJsonObject(parser); + final Map methodParameters = new HashMap<>(); + for (Map.Entry requestParameter : methodParametersJson.entrySet()) { + final String name = requestParameter.getKey(); + final Object value = 
requestParameter.getValue(); + final MethodParameter parameter = MethodParameter.enumOf(name); + if (parameter == null) { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown method parameter found [" + name + "]"); + } + + try { + // This makes sure that we throw parsing exception on rest layer. + methodParameters.put(name, parameter.parse(value)); + } catch (final Exception exception) { + throw new ParsingException(parser.getTokenLocation(), exception.getMessage()); + } + } + return methodParameters.isEmpty() ? null : methodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java new file mode 100644 index 000000000..5c3aa5182 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query.request; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.Version; +import org.opensearch.common.ValidationException; +import org.opensearch.core.ParseField; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; + +/** + * MethodParameters are engine and algorithm related parameters that clients can pass in knn query + * This enum holds metadata which helps parse and have basic validation related to MethodParameter + */ +@Getter +@RequiredArgsConstructor +public enum MethodParameter { + + // TODO: change the version to 2.16 when merging into 2.x + EF_SEARCH(METHOD_PARAMETER_EF_SEARCH, Version.CURRENT, EF_SEARCH_FIELD) { + @Override + public Integer parse(Object value) { + try { + return Integer.parseInt(String.valueOf(value)); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException(METHOD_PARAMETER_EF_SEARCH + " value must be an integer"); + } + } + + @Override + public ValidationException validate(Object value) { + final Integer ef = parse(value); + if (ef != null && ef > 0) { + return null; + } + ; + ValidationException validationException = new ValidationException(); + validationException.addValidationError(METHOD_PARAMETER_EF_SEARCH + " should be greater than 0"); + return validationException; + } + }; + + private final String name; + private final Version version; + private final ParseField parseField; + + private static Map PARAMETERS_DIR; + + public abstract T parse(Object value); + + // These are preliminary validations on rest layer + public abstract ValidationException validate(Object value); + + public static MethodParameter enumOf(final String name) { + if (PARAMETERS_DIR == null) { + PARAMETERS_DIR = new HashMap<>(); + for (final MethodParameter methodParameter : 
MethodParameter.values()) { + PARAMETERS_DIR.put(methodParameter.name, methodParameter); + } + } + return PARAMETERS_DIR.get(name); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java index 0fe311094..932eba598 100644 --- a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java @@ -9,6 +9,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import org.opensearch.common.ValidationException; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.training.VectorSpaceInfo; @@ -22,6 +23,7 @@ public abstract class AbstractKNNLibrary implements KNNLibrary { protected final Map methods; + protected final Map engineMethods; @Getter protected final String version; @@ -34,6 +36,15 @@ public KNNMethod getMethod(String methodName) { return method; } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + EngineSpecificMethodContext method = engineMethods.get(methodName); + if (method == null) { + throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName)); + } + return method; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index bbb58bf1e..0159ac0d0 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -8,6 +8,8 @@ import com.google.common.collect.ImmutableMap; import lombok.AllArgsConstructor; import org.opensearch.knn.common.KNNConstants; +import 
org.opensearch.knn.engine.method.EngineSpecificMethodContext; +import org.opensearch.knn.engine.method.DefaultHnswContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponent; @@ -330,7 +332,13 @@ private Faiss( String extension, Map> scoreTransform ) { - super(methods, scoreTranslation, currentVersion, extension); + super( + methods, + Map.of(METHOD_HNSW, new DefaultHnswContext(), METHOD_IVF, EngineSpecificMethodContext.EMPTY), + scoreTranslation, + currentVersion, + extension + ); this.scoreTransform = scoreTransform; } diff --git a/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java b/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java index e1d48cb0a..235ccdb11 100644 --- a/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.util; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; @@ -23,8 +24,8 @@ public abstract class JVMLibrary extends AbstractKNNLibrary { * @param methods Map of k-NN methods that the library supports * @param version String representing version of library */ - JVMLibrary(Map methods, String version) { - super(methods, version); + JVMLibrary(Map methods, Map engineMethodMetadataMap, String version) { + super(methods, engineMethodMetadataMap, version); } @Override diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 556785783..670d62d0e 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableSet; import org.opensearch.common.ValidationException; +import 
org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -149,6 +150,11 @@ public KNNMethod getMethod(String methodName) { return knnLibrary.getMethod(methodName); } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + return knnLibrary.getMethodContext(methodName); + } + @Override public float score(float rawScore, SpaceType spaceType) { return knnLibrary.score(rawScore, spaceType); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index cac5af2bb..24abcd601 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.util; import org.opensearch.common.ValidationException; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -58,6 +59,13 @@ public interface KNNLibrary { */ KNNMethod getMethod(String methodName); + /** + * Gets metadata related to methods supported by the library + * @param methodName + * @return + */ + EngineSpecificMethodContext getMethodContext(String methodName); + /** * Generate the Lucene score from the rawScore returned by the library. With k-NN, often times the library * will return a score where the lower the score, the better the result. 
This is the opposite of how Lucene scores diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 630d7a2c2..1ba8aa158 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.util.Version; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponent; @@ -67,7 +68,7 @@ public class Lucene extends JVMLibrary { * @param distanceTransform Map of space type to distance transformation function */ Lucene(Map methods, String version, Map> distanceTransform) { - super(methods, version); + super(methods, Map.of(METHOD_HNSW, EngineSpecificMethodContext.EMPTY), version); this.distanceTransform = distanceTransform; } diff --git a/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java b/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java index 5e264ed12..57836177e 100644 --- a/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java @@ -7,6 +7,7 @@ import lombok.Getter; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -36,11 +37,12 @@ abstract class NativeLibrary extends AbstractKNNLibrary { */ NativeLibrary( Map methods, + Map engineMethods, Map> scoreTranslation, String version, String extension ) { - super(methods, version); + super(methods, engineMethods, version); this.scoreTranslation = scoreTranslation; this.extension = extension; this.initialized = new AtomicBoolean(false); diff --git 
a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 64af43520..8677fae96 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.util; import com.google.common.collect.ImmutableMap; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponent; @@ -66,7 +67,7 @@ private Nmslib( String currentVersion, String extension ) { - super(methods, scoreTranslation, currentVersion, extension); + super(methods, Map.of(METHOD_HNSW, EngineSpecificMethodContext.EMPTY), scoreTranslation, currentVersion, extension); } @Override diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 53980bbb7..77b786421 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -129,7 +129,13 @@ public static native void createIndexFromTemplate( * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of k neighbors */ - public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, int[] parentIds); + public static native KNNQueryResult[] queryIndex( + long indexPointer, + float[] queryVector, + int k, + Map methodParameters, + int[] parentIds + ); /** * Query an index with filter @@ -145,6 +151,7 @@ public static native KNNQueryResult[] queryIndexWithFilter( long indexPointer, float[] queryVector, int k, + Map methodParameters, long[] filterIds, int filterIdsType, int[] parentIds diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 20c418819..cc1f9be3c 100644 --- 
a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -173,6 +174,7 @@ public static KNNQueryResult[] queryIndex( long indexPointer, float[] queryVector, int k, + @Nullable Map methodParameters, KNNEngine knnEngine, long[] filteredIds, int filterIdsType, @@ -188,9 +190,17 @@ public static KNNQueryResult[] queryIndex( // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, filterIdsType, parentIds); + return FaissService.queryIndexWithFilter( + indexPointer, + queryVector, + k, + methodParameters, + filteredIds, + filterIdsType, + parentIds + ); } - return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); + return FaissService.queryIndex(indexPointer, queryVector, k, methodParameters, parentIds); } throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); } diff --git a/src/main/java/org/opensearch/knn/validation/ParameterValidator.java b/src/main/java/org/opensearch/knn/validation/ParameterValidator.java new file mode 100644 index 000000000..15925fffa --- /dev/null +++ b/src/main/java/org/opensearch/knn/validation/ParameterValidator.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.validation; + +import org.opensearch.common.Nullable; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.Parameter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public final class ParameterValidator { + + /** + * A function which validates request parameters. + * @param validParameters A set of valid parameters that can be requestParameters can be validated against + * @param requestParameters parameters from the request + * @return + */ + @Nullable + public static ValidationException validateParameters( + final Map> validParameters, + final Map requestParameters + ) { + + if (validParameters == null) { + throw new IllegalArgumentException("validParameters cannot be null"); + } + + if (requestParameters == null || requestParameters.isEmpty()) { + return null; + } + + final List errorMessages = new ArrayList<>(); + for (Map.Entry parameter : requestParameters.entrySet()) { + if (validParameters.containsKey(parameter.getKey())) { + final ValidationException parameterValidation = validParameters.get(parameter.getKey()).validate(parameter.getValue()); + if (parameterValidation != null) { + errorMessages.addAll(parameterValidation.validationErrors()); + } + } else { + errorMessages.add("Unknown parameter '" + parameter.getKey() + "' found"); + } + } + + if (errorMessages.isEmpty()) { + return null; + } + + final ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } +} diff --git a/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java new file mode 100644 index 000000000..ae8788d17 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java @@ -0,0 +1,194 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + 
* this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.BeforeClass; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; + +@AllArgsConstructor +public class FaissHNSWFlatE2EIT extends KNNRestTestCase { + + private String description; + private int k; + private Map methodParameters; + private 
boolean deleteRandomDocs; + + static TestUtils.TestData testData; + + @BeforeClass + public static void setUpClass() throws IOException { + if (FaissHNSWFlatE2EIT.class.getClassLoader() == null) { + throw new IllegalStateException("ClassLoader of FaissIT Class is null"); + } + URL testIndexVectors = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json"); + URL testQueries = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_queries_100x128.csv"); + assert testIndexVectors != null; + assert testQueries != null; + testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s") + public static Collection parameters() { + return Arrays.asList( + $$( + $("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false), + $("Valid k, efsearch absent", 10, null, false), + $("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true), + $("Has delete docs", 10, null, true) + ) + ); + } + + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + 
.field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + // Delete few Docs + if (deleteRandomDocs) { + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + } + + // Test search queries + for (int i = 0; i < testData.queries.length; i++) { + final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder() + .fieldName(fieldName) + .vector(testData.queries[i]) + .k(k) + .methodParameters(methodParameters) + .build(); + Response response = searchKNNIndex(indexName, queryBuilder, k); + String responseBody = EntityUtils.toString(response.getEntity()); + List 
knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } +} diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index b018740bc..6cf5a3177 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -37,12 +37,10 @@ import java.net.URL; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -90,197 +88,6 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - @SneakyThrows - public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { - String indexName = "test-index-1"; - String fieldName = "test-field-1"; - - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); - SpaceType spaceType = SpaceType.L2; - - List mValues = ImmutableList.of(16, 32, 64, 128); - List efConstructionValues = ImmutableList.of(16, 32, 64, 128); - List efSearchValues = ImmutableList.of(16, 32, 64, 128); - - Integer dimension = testData.indexData.vectors[0].length; - - // Create an index - 
XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(NAME, hnswMethod.getMethodComponent().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - Map mappingMap = xContentBuilderToMap(builder); - String mapping = builder.toString(); - - createKnnIndex(indexName, mapping); - assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - - // Index the test data - for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc( - indexName, - Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() - ); - } - - // Assert we have the right number of documents in the index - refreshAllNonSystemIndices(); - assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - - int k = 10; - for (int i = 0; i < testData.queries.length; i++) { - Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); - String responseBody = EntityUtils.toString(response.getEntity()); - List knnResults = parseSearchResponse(responseBody, fieldName); - assertEquals(k, knnResults.size()); - - List actualScores = parseSearchResponseScore(responseBody, fieldName); - for (int j = 0; j < k; j++) { - float[] primitiveArray = knnResults.get(j).getVector(); - assertEquals( - 
KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), - actualScores.get(j), - 0.0001 - ); - } - } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); - } - - @SneakyThrows - public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { - String indexName = "test-index-1"; - String fieldName = "test-field-1"; - - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); - SpaceType spaceType = SpaceType.L2; - - List mValues = ImmutableList.of(16, 32, 64, 128); - List efConstructionValues = ImmutableList.of(16, 32, 64, 128); - List efSearchValues = ImmutableList.of(16, 32, 64, 128); - - Integer dimension = testData.indexData.vectors[0].length; - - // Create an index - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(NAME, hnswMethod.getMethodComponent().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - Map mappingMap = xContentBuilderToMap(builder); - String mapping = builder.toString(); - - createKnnIndex(indexName, mapping); - assertEquals(new TreeMap<>(mappingMap), 
new TreeMap<>(getIndexMappingAsMap(indexName))); - - // Index the test data - for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc( - indexName, - Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() - ); - } - - // Assert we have the right number of documents in the index - refreshAllNonSystemIndices(); - assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - - final Set docIdsToBeDeleted = new HashSet<>(); - while (docIdsToBeDeleted.size() < 10) { - docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); - } - - for (Integer id : docIdsToBeDeleted) { - deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); - } - refreshAllNonSystemIndices(); - forceMergeKnnIndex(indexName, 3); - - assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); - - int k = 10; - for (int i = 0; i < testData.queries.length; i++) { - Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); - String responseBody = EntityUtils.toString(response.getEntity()); - List knnResults = parseSearchResponse(responseBody, fieldName); - assertEquals(k, knnResults.size()); - - List actualScores = parseSearchResponseScore(responseBody, fieldName); - for (int j = 0; j < k; j++) { - float[] primitiveArray = knnResults.get(j).getVector(); - assertEquals( - KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), - actualScores.get(j), - 0.0001 - ); - } - } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); - } - @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() 
{ KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 2ce3a7c83..f7c9f3eb8 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -59,6 +59,7 @@ import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -70,6 +71,9 @@ public class KNN80DocValuesConsumerTests extends KNNTestCase { + private static final int EF_SEARCH = 10; + private static final Map HNSW_METHODPARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); + private static Directory directory; private static Codec codec; @@ -202,7 +206,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by nmslib - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -255,7 +259,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException assertValidFooter(state.directory, expectedFile); // The 
document should be readable by nmslib - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -316,7 +320,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -411,7 +415,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index a60157580..fe8200375 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -326,6 +326,7 @@ public static void assertValidFooter(Directory dir, String filename) throws IOEx } public static void assertLoadableByEngine( + Map methodParameters, SegmentWriteState state, String fileName, KNNEngine knnEngine, @@ -337,7 +338,7 @@ 
public static void assertLoadableByEngine( long indexPtr = JNIService.loadIndex(filePath, Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), knnEngine); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, methodParameters, knnEngine, null, 0, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index a84974202..876303523 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -76,7 +76,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, null, knnEngine, null, 0, null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java new file mode 100644 index 000000000..c7084210e --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import 
org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class KNNQueryBuilderInvalidParamsTests extends KNNTestCase { + + private static final float[] QUERY_VECTOR = new float[] { 1.2f, 2.3f, 4.5f }; + private static final String FIELD_NAME = "test_vector"; + + private String description; + private String expectedMessage; + private KNNQueryBuilder.Builder knnQueryBuilderBuilder; + + @ParametersFactory(argumentFormatting = "description:%1$s; expectedMessage:%2$s; querybuilder:%3$s") + public static Collection invalidParameters() { + return Arrays.asList( + $$( + $("fieldName absent", "[knn] requires fieldName", KNNQueryBuilder.builder().k(1).vector(QUERY_VECTOR)), + $("vector absent", "[knn] requires query vector", KNNQueryBuilder.builder().k(1).fieldName(FIELD_NAME)), + $( + "vector empty", + "[knn] query vector is empty", + KNNQueryBuilder.builder().k(1).fieldName(FIELD_NAME).vector(new float[] {}) + ), + $( + "Neither knn nor radial search", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR) + ), + $( + "max distance and k present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).maxDistance(10f) + ), + $( + "min_score and k present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).minScore(1.0f) + ), + $( + "max_dist and min_score present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(1.0f).minScore(1.0f) + ), + $( + "max_dist, k and min_score present", + "[knn] 
requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).maxDistance(1.0f).minScore(1.0f) + ), + $( + "-ve k value", + "[knn] requires k to be in the range (0, 10000]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(-1) + ), + $( + "k value greater than max", + "[knn] requires k to be in the range (0, 10000]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(10001) + ), + $( + "efSearch 0", + "[knn] errors in method parameter [Validation Failed: 1: Validation Failed: 1: ef_search should be greater than 0;;]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).methodParameters(Map.of("ef_search", 0)).k(10) + ), + $( + "efSearch -ve", + "[knn] errors in method parameter [Validation Failed: 1: Validation Failed: 1: ef_search should be greater than 0;;]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).methodParameters(Map.of("ef_search", -10)).k(10) + ), + $( + "min score less than 0", + "[knn] requires minScore to be greater than 0", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(-1f) + ) + ) + ); + } + + public void testInvalidBuilder() { + Throwable exception = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilderBuilder.build()); + assertEquals(expectedMessage, exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 09aaaa6b9..499909a13 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -7,29 +7,29 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.FloatVectorSimilarityQuery; -import java.util.Locale; import org.apache.lucene.search.KnnFloatVectorQuery; 
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.Index; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.index.Index; -import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; @@ -46,22 +46,26 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static 
org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; + private static final int EF_SEARCH = 10; + private static final Map HNSW_METHOD_PARAMS = Map.of("ef_search", EF_SEARCH); private static final Float MAX_DISTANCE = 1.0f; private static final Float MIN_SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); @@ -91,7 +95,10 @@ public void testInvalidDistance() { /** * null distance */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(null)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).maxDistance(null).build() + ); } public void testInvalidScore() { @@ -99,17 +106,26 @@ public void testInvalidScore() { /** * null min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(null)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(null).build() + ); /** * negative min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(-1.0f)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(-1.0f).build() + ); /** * min_score = 0 */ - 
expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(0.0f)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(0.0f).build() + ); } public void testEmptyVector() { @@ -129,13 +145,19 @@ public void testEmptyVector() { * null query vector with distance */ float[] queryVector2 = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).maxDistance(MAX_DISTANCE)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector2).maxDistance(MAX_DISTANCE).build() + ); /** * empty query vector with distance */ float[] queryVector3 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).maxDistance(MAX_DISTANCE)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector3).maxDistance(MAX_DISTANCE).build() + ); } public void testFromXContent() throws Exception { @@ -154,9 +176,37 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_KnnWithMethodParameters() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), 
EF_SEARCH); + builder.endObject(); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -172,12 +222,16 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSuccee public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MAX_DISTANCE) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -208,6 +262,32 @@ public void testFromXContent_withFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + 
public void testFromXContent_KnnWithEfSearch_withFilter() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -215,7 +295,13 @@ public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_ knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -237,12 +323,17 @@ 
public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_the knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MIN_SCORE) + .filter(TERM_QUERY) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -384,7 +475,11 @@ public void testDoToQuery_Normal() throws Exception { public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -392,7 +487,10 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); 
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); @@ -400,13 +498,19 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th assertTrue(query.toString().contains("resultSimilarity=" + resultSimilarity)); assertTrue( - query.toString().contains("traversalSimilarity=" + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity) + query.toString() + .contains( + "traversalSimilarity=" + + org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity + ) ); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(MIN_SCORE).build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -414,7 +518,10 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS when(mockKNNVectorField.getDimension()).thenReturn(4); 
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); @@ -424,7 +531,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -432,7 +544,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext 
methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -448,7 +563,12 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -456,7 +576,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -470,7 +593,8 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp public void 
testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -478,7 +602,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -494,7 +621,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); 
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -502,7 +629,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -516,7 +646,12 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -524,7 +659,10 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new 
MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -540,7 +678,13 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -548,7 +692,10 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -559,9 +706,15 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then 
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } - public void testDoToQuery_KnnQueryWithFilter() throws Exception { + public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { + // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -569,25 +722,42 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + // When Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + + // Then assertNotNull(query); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new 
KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -598,14 +768,22 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MIN_SCORE) + .filter(TERM_QUERY) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -615,21 +793,58 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS } public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { + // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); 
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + // When + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + + // Then assertNotNull(query); assertTrue(query.getClass().isAssignableFrom(KNNQuery.class)); + assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters()); + } + + public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { + + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of())) + ); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .methodParameters(Map.of("ef_search", 10)) + .build(); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() { @@ -641,7 +856,10 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new 
MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -669,6 +887,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -682,7 +901,13 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -699,6 +924,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + 
when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -715,7 +941,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(MIN_SCORE).build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -732,6 +960,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -815,27 +1044,41 @@ public void testDoToQuery_InvalidZeroByteVector() { public void testSerialization() throws Exception { // For k-NN search - assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null, null); + 
assertSerialization(Version.CURRENT, Optional.empty(), K, EF_SEARCH, EF_SEARCH, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, EF_SEARCH, EF_SEARCH, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, EF_SEARCH, null, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, null, MIN_SCORE); } private void assertSerialization( final Version version, final Optional queryBuilderOptional, Integer k, + @Nullable Integer requestEfSearch, + Integer expectedEfSearch, Float distance, Float score ) throws Exception { - final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); + Map methodParameters = requestEfSearch == null ? 
null : Map.of("ef_search", requestEfSearch); + + final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(distance) + .minScore(score) + .k(k) + .methodParameters(methodParameters) + .filter(queryBuilderOptional.orElse(null)) + .build(); final ClusterService clusterService = mockClusterService(version); @@ -856,6 +1099,13 @@ private void assertSerialization( assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); if (k != null) { assertEquals(k.intValue(), deserializedKnnQueryBuilder.getK()); + Integer actualEfSearch = methodParameters == null ? null : (Integer) methodParameters.get("ef_search"); + // Verifies efSearch + if (version.onOrAfter(Version.V_3_0_0)) { + assertEquals(expectedEfSearch, actualEfSearch); + } else { + assertNull(deserializedKnnQueryBuilder.getMethodParameters()); + } } else if (distance != null) { assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getMaxDistance(), 0.0f); } else { @@ -871,36 +1121,19 @@ private void assertSerialization( } } - private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBuilderOptional, Integer k, Float distance, Float score) { - final KNNQueryBuilder knnQueryBuilder; - if (k != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); - } else if (distance != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance); - } else if (score != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? 
new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score); - } else { - throw new IllegalArgumentException("Either k or distance must be provided"); - } - return knnQueryBuilder; - } - public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); - knnQueryBuilder.ignoreUnmapped(true); - assertTrue(knnQueryBuilder.getIgnoreUnmapped()); - Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class)); + KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .ignoreUnmapped(true); + assertTrue(knnQueryBuilder.build().isIgnoreUnmapped()); + Query query = knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class)); assertNotNull(query); assertThat(query, instanceOf(MatchNoDocsQuery.class)); knnQueryBuilder.ignoreUnmapped(false); - expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class))); + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class))); } public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { @@ -911,9 +1144,15 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.L2, - new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) + new MethodComponentContext(org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of()) ); - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(MAX_DISTANCE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(MAX_DISTANCE) + .build(); + 
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java new file mode 100644 index 000000000..4b97df4b4 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class KNNQueryBuilderValidParamsTests extends KNNTestCase { + + private static final float[] QUERY_VECTOR = new float[] { 1.2f, 2.3f, 4.5f }; + private static final String FIELD_NAME = "test_vector"; + + private String description; + private KNNQueryBuilder expected; + private Integer k; + private Map methodParameters; + private Float maxDistance; + private Float minScore; + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%3$s, efSearch:%4$s, maxDist:%5$s, minScore:%6$s") + public static Collection validParameters() { + return Arrays.asList( + $$( + $( + "valid knn with k", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(10).build(), + 10, + null, + null, + null + ), + $( + "valid knn with k and efSearch", + KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .k(10) + .methodParameters(Map.of("ef_search", 12)) 
+ .build(), + 10, + Map.of("ef_search", 12), + null, + null + ), + $( + "valid knn with maxDis", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(10.0f).build(), + null, + null, + 10.0f, + null + ), + $( + "valid knn with minScore", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(10.0f).build(), + null, + null, + null, + 10.0f + ) + ) + ); + } + + public void testValidBuilder() { + assertEquals( + expected, + KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .k(k) + .methodParameters(methodParameters) + .maxDistance(maxDistance) + .minScore(minScore) + .build() + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 1bb17cfae..25777460f 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -27,12 +27,14 @@ import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -45,6 +47,7 @@ public class KNNQueryFactoryTests extends KNNTestCase { private final String testIndexName = "test-index"; private final String testFieldName = "test-field"; private final int testK = 10; + private final Map methodParameters = Map.of(METHOD_PARAMETER_EF_SEARCH, 100); public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { @@ -106,28 +109,71 @@ public void 
testCreateLuceneQueryWithFilter() { } public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { + // Given final KNNEngine knnEngine = KNNEngine.FAISS; final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); MappedFieldType testMapper = mock(MappedFieldType.class); when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + + final KNNQuery expectedQuery = KNNQuery.builder() + .indexName(testIndexName) + .filterQuery(FILTER_QUERY) + .field(testFieldName) + .queryVector(testQueryVector) + .k(testK) + .methodParameters(methodParameters) + .build(); + + // When final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .methodParameters(methodParameters) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - final Query query = KNNQueryFactory.create(createQueryRequest); - assertTrue(query instanceof KNNQuery); - assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); - assertEquals(testFieldName, ((KNNQuery) query).getField()); - assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); - assertEquals(testK, ((KNNQuery) query).getK()); - assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); + final Query actual = KNNQueryFactory.create(createQueryRequest); + + // Then + assertEquals(expectedQuery, actual); + } + + public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSuccess() { + // Given + final KNNEngine knnEngine = KNNEngine.FAISS; + final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), 
Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + + final KNNQuery expectedQuery = KNNQuery.builder() + .indexName(testIndexName) + .filterQuery(FILTER_QUERY) + .field(testFieldName) + .queryVector(testQueryVector) + .k(testK) + .build(); + + // When + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + + final Query actual = KNNQueryFactory.create(createQueryRequest); + + // Then + assertEquals(expectedQuery, actual); } public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index adf985b39..abba1f491 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -67,11 +67,13 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -83,6 +85,8 @@ public class KNNWeightTests extends KNNTestCase { private static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); private static final Set SEGMENT_FILES_FAISS = 
Set.of("_0.cfe", "_0_2011_target_field.faissc"); private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; + private static final Integer EF_SEARCH = 10; + private static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); @@ -159,7 +163,7 @@ public void testQueryScoreForFaissWithModel() { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), eq(K), isNull(), any(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -303,7 +307,7 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), eq(K), isNull(), any(), any(), anyInt(), any())) .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -346,6 +350,7 @@ public void testEmptyQueryResults() { @SneakyThrows public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + // Given int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); @@ -353,7 +358,16 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { filterBitSet.set(docId); } 
jniServiceMockedStatic.when( - () -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), eq(filterBitSet.getBits()), anyInt(), any()) + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) ).thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -366,7 +380,15 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(liveDocsBits.length()).thenReturn(1000); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -406,15 +428,26 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); + // When final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); - assertNotNull(knnScorer); + // Then + assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); jniServiceMockedStatic.verify( - () -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), eq(filterBitSet.getBits()), anyInt(), any()) + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) ); final List 
actualDocIds = new ArrayList<>(); @@ -677,17 +710,47 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { // Prepare query and weight when(bitSetProducer.getBitSet(leafReaderContext)).thenReturn(bitset); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, 1, INDEX_NAME, null, bitSetProducer); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(1) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .parentsFilter(bitSetProducer) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), eq(parentsFilter))) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(1), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + eq(parentsFilter) + ) + ).thenReturn(getKNNQueryResults()); // Execute Scorer knnScorer = knnWeight.scorer(leafReaderContext); // Verify - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), eq(parentsFilter))); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(1), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + eq(parentsFilter) + ) + ); assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); @@ -811,10 +874,17 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(K), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()) + 
).thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(K) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java new file mode 100644 index 000000000..e9323f27f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.query.parser; + +import lombok.SneakyThrows; +import org.opensearch.common.ValidationException; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.KNNTestCase; + +import java.util.Map; + +import static org.opensearch.knn.index.query.parser.MethodParametersParser.doXContent; +import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; + +public class MethodParametersParserTests extends KNNTestCase { + + public void testValidateMethodParameters() { + ValidationException validationException = validateMethodParameters(Map.of("dummy", 0)); + assertEquals("Validation Failed: 1: dummy is not a valid method parameter;", validationException.getMessage()); + + ValidationException validationException2 = validateMethodParameters(Map.of("ef_search", 0)); + assertTrue(validationException2.getMessage().contains("Validation Failed: 1: ef_search should be greater than 0")); + + ValidationException validationException3 = validateMethodParameters(Map.of("ef_search", 10)); + assertNull(validationException3); + } + + @SneakyThrows + public void testDoXContent() { + Map params = Map.of("ef_search", 10); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("method_parameters") + .field("ef_search", 10) + .endObject() + .endObject(); + + XContentBuilder builder2 = XContentFactory.jsonBuilder().startObject(); + doXContent(builder2, params); + builder2.endObject(); + assertEquals(builder.toString(), builder2.toString()); + + XContentBuilder b3 = XContentFactory.jsonBuilder(); + XContentBuilder b4 = XContentFactory.jsonBuilder(); + + doXContent(b4, null); + assertEquals(b3.toString(), b4.toString()); + } + + @SneakyThrows + public void testFromXContent() { + // efsearch string + 
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("ef_search", "string").endObject(); + XContentParser parser1 = createParser(builder); + parser1.nextToken(); + expectThrows(ParsingException.class, () -> MethodParametersParser.fromXContent(parser1)); + + // unknown method parameter + builder = XContentFactory.jsonBuilder().startObject().field("unknown", "10").endObject(); + XContentParser parser2 = createParser(builder); + parser2.nextToken(); + expectThrows(ParsingException.class, () -> MethodParametersParser.fromXContent(parser2)); + + // Valid + builder = XContentFactory.jsonBuilder().startObject().field("ef_search", 10).endObject(); + XContentParser parser3 = createParser(builder); + parser3.nextToken(); + assertEquals(Map.of("ef_search", 10), MethodParametersParser.fromXContent(parser3)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 9e6bd67ea..1e6238d89 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -11,10 +11,12 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponent; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.Parameter; import org.opensearch.knn.index.SpaceType; import java.io.IOException; @@ -77,6 +79,20 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { assertNotNull(testAbstractKNNLibrary2.validateMethod(knnMethodContext2)); } + public void testEngineSpecificMethods() throws IOException { + String methodName1 = 
"test-method-1"; + EngineSpecificMethodContext context = () -> Map.of("myparameter", new Parameter.BooleanParameter("myparameter", false, o -> o)); + + TestAbstractKNNLibrary testAbstractKNNLibrary1 = new TestAbstractKNNLibrary( + Collections.emptyMap(), + Map.of(methodName1, context), + "" + ); + + assertNotNull(testAbstractKNNLibrary1.getMethodContext(methodName1)); + assertTrue(testAbstractKNNLibrary1.getMethodContext(methodName1).supportedMethodParameters().containsKey("myparameter")); + } + public void testGetMethodAsMap() { String methodName = "test-method-1"; SpaceType spaceType = SpaceType.DEFAULT; @@ -109,7 +125,15 @@ public void testGetMethodAsMap() { private static class TestAbstractKNNLibrary extends AbstractKNNLibrary { public TestAbstractKNNLibrary(Map methods, String currentVersion) { - super(methods, currentVersion); + super(methods, Collections.emptyMap(), currentVersion); + } + + public TestAbstractKNNLibrary( + Map methods, + Map engineSpecificMethodContextMap, + String currentVersion + ) { + super(methods, engineSpecificMethodContextMap, currentVersion); } @Override diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 3c3afbee6..814712560 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -62,7 +62,7 @@ public TestNativeLibrary( String currentVersion, String extension ) { - super(methods, scoreTranslation, currentVersion, extension); + super(methods, Collections.emptyMap(), scoreTranslation, currentVersion, extension); } @Override diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index d6ae13e92..e71930d48 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -527,6 +527,7 @@ 
public void testQueryIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; int k = 10; + Map methodParameters = Map.of("ef_search", 12); float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); @@ -544,13 +545,22 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, methodParameters, KNNEngine.FAISS, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, new long[] { 0 }, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + methodParameters, + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); assertEquals(0, results.length); } } @@ -736,12 +746,15 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.LUCENE, null, 0, null)); + expectThrows( + IllegalArgumentException.class, + () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.LUCENE, null, 0, null) + ); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.NMSLIB, null, 0, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -765,7 +778,7 @@ 
public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.NMSLIB, null, 0, null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -792,7 +805,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, null, KNNEngine.NMSLIB, null, 0, null); assertEquals(k, results.length); } } @@ -800,7 +813,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.FAISS, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.FAISS, null, 0, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -820,12 +833,13 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.FAISS, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); } public void testQueryIndex_faiss_valid() throws IOException { int k = 10; + int efSearch = 100; List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); @@ -850,13 +864,31 @@ public 
void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + null + ); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, new long[] { 0 }, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); assertEquals(0, results.length); } } @@ -866,6 +898,7 @@ public void testQueryIndex_faiss_valid() throws IOException { public void testQueryIndex_faiss_parentIds() throws IOException { int k = 100; + int efSearch = 100; List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); @@ -892,7 +925,16 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertNotEquals(0, pointer); for (float[] query : testDataNested.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, parentIds); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + parentIds + ); // Verify there is no more than one result from same parent Set parentIdSet = toParentIdSet(results, idToParentIdMap); assertEquals(results.length, parentIdSet.size()); @@ -1223,7 +1265,7 @@ private void assertQueryResultsMatch(float[][] testQueries, int k, List in for (float[] query : testQueries) { KNNQueryResult[][] allResults = new KNNQueryResult[indexAddresses.size()][]; for (int i = 0; i < indexAddresses.size(); i++) { - allResults[i] = 
JNIService.queryIndex(indexAddresses.get(i), query, k, KNNEngine.FAISS, null, 0, null); + allResults[i] = JNIService.queryIndex(indexAddresses.get(i), query, k, null, KNNEngine.FAISS, null, 0, null); assertEquals(k, allResults[i].length); } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 46240e830..ccbe96a32 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.plugin.stats.suppliers; import org.opensearch.common.ValidationException; +import org.opensearch.knn.engine.method.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -59,6 +60,11 @@ public KNNMethod getMethod(String methodName) { return null; } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + return null; + } + @Override public float score(float rawScore, SpaceType spaceType) { return 0; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index fa6a13f2f..ad03fd216 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1122,12 +1122,22 @@ public void addKNNDocs(String testIndex, String testField, int dimension, int fi } } - // Validate KNN search on a KNN index by generating the query vector from the number of documents in the index public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k) throws Exception { + validateKNNSearch(testIndex, testField, dimension, numDocs, k, null); + } + + 
// Validate KNN search on a KNN index by generating the query vector from the number of documents in the index + public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k, Map methodParameters) + throws Exception { float[] queryVector = new float[dimension]; Arrays.fill(queryVector, (float) numDocs); - Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(testField, queryVector, k), k); + Response searchResponse = searchKNNIndex( + testIndex, + KNNQueryBuilder.builder().k(k).methodParameters(methodParameters).fieldName(testField).vector(queryVector).build(), + k + ); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), testField); assertEquals(k, results.size());