Skip to content

Commit

Permalink
Refactor unit tests for codec (#562)
Browse files Browse the repository at this point in the history
* Refactor unit tests for codec for easier Lucene version upgrades

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 28, 2022
1 parent dfa79b2 commit 2c33ad8
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.KNNFormatFactory;

Expand All @@ -27,7 +28,7 @@ public final class KNN920Codec extends FilterCodec {
private static final String KNN920 = "KNN920Codec";

private final KNNFormatFacade knnFormatFacade;
private final KNN920PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene91 as the delegate
Expand All @@ -43,7 +44,7 @@ public KNN920Codec() {
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
public KNN920Codec(Codec delegate, KNN920PerFieldKnnVectorsFormat knnVectorsFormat) {
public KNN920Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(KNN920, delegate);
knnFormatFacade = KNNFormatFactory.createKNN920Format(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,16 @@

package org.opensearch.knn.index.codec.KNN920Codec;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.SerialMergeScheduler;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.KNNCodecTestCase;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.index.query.KNNQueryFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.watcher.ResourceWatcherService;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.index.codec.KNNCodecFactory.CodecDelegateFactory.createKNN92DefaultDelegate;

public class KNN920CodecTests extends KNNCodecTestCase {
Expand All @@ -55,93 +28,14 @@ public void testBuildFromModelTemplate() throws InterruptedException, ExecutionE
}

public void testKnnVectorIndex() throws Exception {
final String fieldName = "test_vector";
final String field1Name = "my_vector";
final MapperService mapperService = mock(MapperService.class);
final KNNMethodContext knnMethodContext = new KNNMethodContext(
KNNEngine.LUCENE,
SpaceType.L2,
new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256))
);
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldMapper.KNNVectorFieldType(
fieldName,
Map.of(),
3,
knnMethodContext
);
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldMapper.KNNVectorFieldType(
field1Name,
Map.of(),
2,
knnMethodContext
);
when(mapperService.fieldType(eq(fieldName))).thenReturn(mappedFieldType1);
when(mapperService.fieldType(eq(field1Name))).thenReturn(mappedFieldType2);
Function<MapperService, PerFieldKnnVectorsFormat> perFieldKnnVectorsFormatProvider = (
mapperService) -> new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService));

var knnVectorsFormat = spy(new KNN920PerFieldKnnVectorsFormat(Optional.of(mapperService)));

final KNN920Codec actualCodec = KNN920Codec.builder()
Function<PerFieldKnnVectorsFormat, Codec> knnCodecProvider = (knnVectorFormat) -> KNN920Codec.builder()
.delegate(createKNN92DefaultDelegate())
.knnVectorsFormat(knnVectorsFormat)
.knnVectorsFormat(knnVectorFormat)
.build();
final KNN920Codec codec = KNN920Codec.builder().delegate(createKNN92DefaultDelegate()).knnVectorsFormat(knnVectorsFormat).build();
setUpMockClusterService();
Directory dir = newFSDirectory(createTempDir());
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergeScheduler(new SerialMergeScheduler());
iwc.setCodec(codec);

/**
* Add doc with field "test_vector"
*/
final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN);
float[] array = { 1.0f, 3.0f, 4.0f };
KnnVectorField vectorField = new KnnVectorField(fieldName, array, luceneFieldType);
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc);
Document doc = new Document();
doc.add(vectorField);
writer.addDocument(doc);
writer.commit();
IndexReader reader = writer.getReader();
writer.close();

verify(knnVectorsFormat).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher = new IndexSearcher(reader);
Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", fieldName, new float[] { 1.0f, 0.0f, 0.0f }, 1);

assertEquals(1, searcher.count(query));

reader.close();

/**
* Add doc with field "my_vector"
*/
IndexWriterConfig iwc1 = newIndexWriterConfig();
iwc1.setMergeScheduler(new SerialMergeScheduler());
iwc1.setCodec(actualCodec);
writer = new RandomIndexWriter(random(), dir, iwc1);
final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN);
float[] array1 = { 6.0f, 14.0f };
KnnVectorField vectorField1 = new KnnVectorField(field1Name, array1, luceneFieldType1);
Document doc1 = new Document();
doc1.add(vectorField1);
writer.addDocument(doc1);
IndexReader reader1 = writer.getReader();
writer.close();
ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService();
NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService);

verify(knnVectorsFormat, times(2)).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher1 = new IndexSearcher(reader1);
Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", field1Name, new float[] { 1.0f, 0.0f }, 1);

assertEquals(1, searcher1.count(query1));

reader1.close();
dir.close();
resourceWatcherService.close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
testKnnVectorIndex(knnCodecProvider, perFieldKnnVectorsFormatProvider);
}
}
108 changes: 108 additions & 0 deletions src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec;
import org.opensearch.knn.index.query.KNNQueryFactory;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.KNNSettings;
Expand Down Expand Up @@ -45,14 +53,24 @@
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.Version.CURRENT;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING;

Expand All @@ -72,6 +90,8 @@ public class KNNCodecTestCase extends KNNTestCase {
sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512");
sampleFieldType.freeze();
}
// Names of the two vector fields that testKnnVectorIndex registers on the mocked MapperService and then indexes and queries.
private static final String FIELD_NAME_ONE = "test_vector_one";
private static final String FIELD_NAME_TWO = "test_vector_two";

protected void setUpMockClusterService() {
ClusterService clusterService = mock(ClusterService.class, RETURNS_DEEP_STUBS);
Expand Down Expand Up @@ -253,4 +273,92 @@ public void testWriteByOldCodec(Codec codec) throws IOException {
dir.close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
}

/**
 * Shared index-and-search check for a codec that is built around a per-field KNN vectors format.
 *
 * <p>The method mocks a {@link MapperService} that resolves two vector fields, wraps the format returned by
 * {@code perFieldKnnVectorsFormatProvider} in a Mockito spy, and builds the codec under test from that spy.
 * It then writes one document per field, using two successive writers on the same directory, and after each
 * write asserts that a KNN query against the field matches exactly one document and that the spied format
 * has been asked for a per-field format once more.
 *
 * @param codecProvider builds the codec under test from the (spied) per-field vectors format
 * @param perFieldKnnVectorsFormatProvider builds the per-field vectors format from the mocked mapper service
 * @throws Exception if any indexing, search, or cleanup step fails
 */
public void testKnnVectorIndex(
final Function<PerFieldKnnVectorsFormat, Codec> codecProvider,
final Function<MapperService, PerFieldKnnVectorsFormat> perFieldKnnVectorsFormatProvider
) throws Exception {
final MapperService mapperService = mock(MapperService.class);
// Method context shared by both fields: Lucene engine, L2 space type, HNSW with HNSW_ALGO_M = 16 and
// HNSW_ALGO_EF_CONSTRUCTION = 256.
final KNNMethodContext knnMethodContext = new KNNMethodContext(
KNNEngine.LUCENE,
SpaceType.L2,
new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256))
);
// The third constructor argument (3 here, 2 below) equals the length of the vector indexed for that field
// further down — presumably the field dimension; TODO confirm against KNNVectorFieldType.
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldMapper.KNNVectorFieldType(
FIELD_NAME_ONE,
Map.of(),
3,
knnMethodContext
);
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldMapper.KNNVectorFieldType(
FIELD_NAME_TWO,
Map.of(),
2,
knnMethodContext
);
// The mocked mapper service resolves each field name to the field type built above.
when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1);
when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2);

// Spy on the format so the getKnnVectorsFormatForField calls can be counted with verify(...) below;
// the codec under test is built from the spy, not from the raw format.
var perFieldKnnVectorsFormatSpy = spy(perFieldKnnVectorsFormatProvider.apply(mapperService));
final Codec codec = codecProvider.apply(perFieldKnnVectorsFormatSpy);

setUpMockClusterService();
Directory dir = newFSDirectory(createTempDir());
IndexWriterConfig iwc = newIndexWriterConfig();
iwc.setMergeScheduler(new SerialMergeScheduler());
iwc.setCodec(codec);

// Add doc with field "test_vector_one" (a 3-element vector, EUCLIDEAN similarity).
final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN);
float[] array = { 1.0f, 3.0f, 4.0f };
KnnVectorField vectorField = new KnnVectorField(FIELD_NAME_ONE, array, luceneFieldType);
RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc);
Document doc = new Document();
doc.add(vectorField);
writer.addDocument(doc);
writer.commit();
// The reader is obtained before the writer is closed and is closed separately below.
IndexReader reader = writer.getReader();
writer.close();

// verify(...) without a mode uses Mockito's default of exactly one invocation.
verify(perFieldKnnVectorsFormatSpy).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher = new IndexSearcher(reader);
// NOTE(review): "dummy" is presumably the index name and the trailing 1 presumably k — confirm against KNNQueryFactory.create.
Query query = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_ONE, new float[] { 1.0f, 0.0f, 0.0f }, 1);

assertEquals(1, searcher.count(query));

reader.close();

// Add doc with field "test_vector_two" (a 2-element vector) through a second writer that uses a fresh
// IndexWriterConfig with the same codec on the same directory.
IndexWriterConfig iwc1 = newIndexWriterConfig();
iwc1.setMergeScheduler(new SerialMergeScheduler());
iwc1.setCodec(codec);
writer = new RandomIndexWriter(random(), dir, iwc1);
final FieldType luceneFieldType1 = KnnVectorField.createFieldType(2, VectorSimilarityFunction.EUCLIDEAN);
float[] array1 = { 6.0f, 14.0f };
KnnVectorField vectorField1 = new KnnVectorField(FIELD_NAME_TWO, array1, luceneFieldType1);
Document doc1 = new Document();
doc1.add(vectorField1);
writer.addDocument(doc1);
IndexReader reader1 = writer.getReader();
writer.close();
// NOTE(review): the load strategy is initialized only at this point, after indexing; presumably this is
// needed so that the getInstance().close() call at the end has an instance to close — confirm.
ResourceWatcherService resourceWatcherService = createDisabledResourceWatcherService();
NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService);

// The spy keeps counting across both writers, so the expected total is now two invocations.
verify(perFieldKnnVectorsFormatSpy, times(2)).getKnnVectorsFormatForField(anyString());

IndexSearcher searcher1 = new IndexSearcher(reader1);
Query query1 = KNNQueryFactory.create(KNNEngine.LUCENE, "dummy", FIELD_NAME_TWO, new float[] { 1.0f, 0.0f }, 1);

assertEquals(1, searcher1.count(query1));

// Release the second reader, the directory, the watcher service and the load strategy instance.
reader1.close();
dir.close();
resourceWatcherService.close();
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close();
}
}

0 comments on commit 2c33ad8

Please sign in to comment.