From 2c33ad89439ee18c3d2cbfafc43ade31ef400e61 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 20 Sep 2022 11:57:20 -0700 Subject: [PATCH] Refactor unit tests for codec (#562) * Refactor unit test for codec for easier lucene version upgrades Signed-off-by: Martin Gaievski --- .../index/codec/KNN920Codec/KNN920Codec.java | 5 +- .../codec/KNN920Codec/KNN920CodecTests.java | 122 ++---------------- .../knn/index/codec/KNNCodecTestCase.java | 108 ++++++++++++++++ 3 files changed, 119 insertions(+), 116 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java index 8d8e664ce..26abcea60 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920Codec.java @@ -11,6 +11,7 @@ import org.apache.lucene.codecs.DocValuesFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNNFormatFacade; import org.opensearch.knn.index.codec.KNNFormatFactory; @@ -27,7 +28,7 @@ public final class KNN920Codec extends FilterCodec { private static final String KNN920 = "KNN920Codec"; private final KNNFormatFacade knnFormatFacade; - private final KNN920PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; /** * No arg constructor that uses Lucene91 as the delegate @@ -43,7 +44,7 @@ public KNN920Codec() { * @param knnVectorsFormat per field format for KnnVector */ @Builder - public KNN920Codec(Codec delegate, KNN920PerFieldKnnVectorsFormat knnVectorsFormat) { + public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { super(KNN920, delegate); knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate); perFieldKnnVectorsFormat = knnVectorsFormat; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java index 8a3233fba..e4f848f6a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920CodecTests.java @@ -5,43 +5,16 @@ package org.opensearch.knn.index.codec.KNN920Codec; -import org.apache.lucene.document.Document; -import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.KnnVectorField; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.SerialMergeScheduler; -import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.store.Directory; -import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.MethodComponentContext; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.KNNCodecTestCase; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; -import org.opensearch.knn.index.query.KNNQueryFactory; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.watcher.ResourceWatcherService; import java.io.IOException; -import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; +import java.util.function.Function; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate; public class KNN920CodecTests extends KNNCodecTestCase { @@ -55,93 +28,14 @@ public void testBuildFromModelTemplate() throws InterruptedException, ExecutionE } public void testKnnVectorIndex() throws Exception { - final String fieldName = "test_vector"; - final String field1Name = "my_vector"; - final MapperService mapperService = mock(MapperService.class); - final KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.L2, - new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) - ); - final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldMapper.KNNVectorFieldType( - fieldName, - Map.of(), - 3, - knnMethodContext - ); - final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldMapper.KNNVectorFieldType( - field1Name, - Map.of(), - 2, - knnMethodContext - ); - when(mapperService.fieldType(eq(fieldName))).thenReturn(mappedFieldType1); - when(mapperService.fieldType(eq(field1Name))).thenReturn(mappedFieldType2); + Function perFieldKnnVectorsFormatProvider = ( + mapperService) -> new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService)); - var knnVectorsFormat = spy(new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService))); - - final KNN920Codec actualCodec = KNN920Codec.builder() + Function knnCodecProvider = (knnVectorFormat) -> KNN920Codec.builder() .delegate(createKNN92DefaultDelegate()) - .knnVectorsFormat(knnVectorsFormat) + .knnVectorsFormat(knnVectorFormat) .build(); - final KNN920Codec codec = KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).knnVectorsFormat(knnVectorsFormat).build(); - setUpMockClusterService(); - Directory dir = newFSDirectory(createTempDir()); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergeScheduler(new SerialMergeScheduler()); - iwc.setCodec(codec); - - /** - * Add doc with field "test_vector" - */ - final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); - float[] array = { 1.0f, 3.0f, 4.0f }; - KnnVectorField vectorField = new KnnVectorField(fieldName, array, luceneFieldType); - RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); - Document doc = new Document(); - doc.add(vectorField); - writer.addDocument(doc); - writer.commit(); - IndexReader reader = writer.getReader(); - writer.close(); - - verify(knnVectorsFormat).getKnnVectorsFormatForField(anyString()); - - IndexSearcher searcher = new IndexSearcher(reader); - Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", fieldName, new float[] { 1.0f, 0.0f, 0.0f }, 1); - - assertEquals(1, searcher.count(query)); - - reader.close(); - - /** - * Add doc with field "my_vector" - */ - IndexWriterConfig iwc1 = newIndexWriterConfig(); - iwc1.setMergeScheduler(new SerialMergeScheduler()); - iwc1.setCodec(actualCodec); - writer = new RandomIndexWriter(random(), dir, iwc1); - final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); - float[] array1 = { 6.0f, 14.0f }; - KnnVectorField vectorField1 = new KnnVectorField(field1Name, array1, luceneFieldType1); - Document doc1 = new Document(); - doc1.add(vectorField1); - writer.addDocument(doc1); - IndexReader reader1 = writer.getReader(); - writer.close(); - ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); - NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); - - verify(knnVectorsFormat, times(2)).getKnnVectorsFormatForField(anyString()); - - IndexSearcher searcher1 = new IndexSearcher(reader1); - Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", field1Name, new float[] { 1.0f, 0.0f }, 1); - - assertEquals(1, searcher1.count(query1)); - reader1.close(); - dir.close(); - resourceWatcherService.close(); - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + testKnnVectorIndex(knnCodecProvider, perFieldKnnVectorsFormatProvider); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index fce944ee6..7991db135 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -7,11 +7,19 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; +import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.KNNSettings; @@ -45,14 +53,24 @@ import java.time.ZonedDateTime; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.function.Function; import java.util.stream.Collectors; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.Version.CURRENT; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -72,6 +90,8 @@ public class KNNCodecTestCase extends KNNTestCase { sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); sampleFieldType.freeze(); } + private static final String FIELD_NAME_ONE = "test_vector_one"; + private static final String FIELD_NAME_TWO = "test_vector_two"; protected void setUpMockClusterService() { ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS); @@ -253,4 +273,92 @@ public void testWriteByOldCodec(Codec codec) throws IOException { dir.close(); NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); } + + public void testKnnVectorIndex( + final Function codecProvider, + final Function perFieldKnnVectorsFormatProvider + ) throws Exception { + final MapperService mapperService = mock(MapperService.class); + final KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.L2, + new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) + ); + final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldMapper.KNNVectorFieldType( + FIELD_NAME_ONE, + Map.of(), + 3, + knnMethodContext + ); + final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldMapper.KNNVectorFieldType( + FIELD_NAME_TWO, + Map.of(), + 2, + knnMethodContext + ); + when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); + when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); + + var perFieldKnnVectorsFormatSpy = spy(perFieldKnnVectorsFormatProvider.apply(mapperService)); + final Codec codec = codecProvider.apply(perFieldKnnVectorsFormatSpy); + + setUpMockClusterService(); + Directory dir = newFSDirectory(createTempDir()); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(codec); + + /** + * Add doc with field "test_vector_one" + */ + final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); + float[] array = { 1.0f, 3.0f, 4.0f }; + KnnVectorField vectorField = new KnnVectorField(FIELD_NAME_ONE, array, luceneFieldType); + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + doc.add(vectorField); + writer.addDocument(doc); + writer.commit(); + IndexReader reader = writer.getReader(); + writer.close(); + + verify(perFieldKnnVectorsFormatSpy).getKnnVectorsFormatForField(anyString()); + + IndexSearcher searcher = new IndexSearcher(reader); + Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_ONE, new float[] { 1.0f, 0.0f, 0.0f }, 1); + + assertEquals(1, searcher.count(query)); + + reader.close(); + + /** + * Add doc with field "test_vector_two" + */ + IndexWriterConfig iwc1 = newIndexWriterConfig(); + iwc1.setMergeScheduler(new SerialMergeScheduler()); + iwc1.setCodec(codec); + writer = new RandomIndexWriter(random(), dir, iwc1); + final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN); + float[] array1 = { 6.0f, 14.0f }; + KnnVectorField vectorField1 = new KnnVectorField(FIELD_NAME_TWO, array1, luceneFieldType1); + Document doc1 = new Document(); + doc1.add(vectorField1); + writer.addDocument(doc1); + IndexReader reader1 = writer.getReader(); + writer.close(); + ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService(); + NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); + + verify(perFieldKnnVectorsFormatSpy, times(2)).getKnnVectorsFormatForField(anyString()); + + IndexSearcher searcher1 = new IndexSearcher(reader1); + Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_TWO, new float[] { 1.0f, 0.0f }, 1); + + assertEquals(1, searcher1.count(query1)); + + reader1.close(); + dir.close(); + resourceWatcherService.close(); + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + } }