Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed Mar 28, 2024
1 parent 32762a5 commit f30c093
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,18 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
DocIdSetIterator values = null;
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
System.out.println(fieldInfo);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values = null;
if (fieldInfo.hasVectorValues()) {
values = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32
? reader.getFloatVectorValues(fieldName)
: reader.getByteVectorValues(fieldName);
System.out.println("use vector values");
} else {
values = DocValues.getBinary(reader, fieldName);
System.out.println("use binary values");
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -60,7 +61,17 @@ public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
* Creates a KNNVectorScriptDocValues object based on the provided parameters.
*
* @param values The DocIdSetIterator representing the vector values.
* @param fieldName The name of the field.
* @param vectorDataType The data type of the vector.
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
Expand Down Expand Up @@ -118,4 +129,20 @@ protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromDocValues(values.binaryValue());
}
}

/**
* Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type.
*
* @param fieldName The name of the field.
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand All @@ -13,7 +22,6 @@
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.store.Directory;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {

private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name";
private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f };
private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 };
private KNNVectorScriptDocValues scriptDocValues;
private Directory directory;
private DirectoryReader reader;
Expand All @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
createKNNVectorDocument(directory);
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = KNNVectorScriptDocValues.create(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
}

private void createKNNVectorDocument(Directory directory) throws IOException {
private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
knnDocument.add(
new BinaryDocValuesField(
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
)
);
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -83,4 +105,18 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

public void testUnsupportedValues() throws IOException {
expectThrows(
IllegalArgumentException.class,
() -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
);
}

public void testEmptyValues() throws IOException {
KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
assertEquals(0, values.size());
scriptDocValues.setNextDocId(0);
assertEquals(0, values.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.knn.plugin.script;

import java.io.IOException;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
Expand All @@ -21,6 +23,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.script.Script;

import java.util.ArrayList;
Expand All @@ -35,12 +38,29 @@
import static org.hamcrest.Matchers.containsString;

public class KNNScriptScoringIT extends KNNRestTestCase {
private void randomCreateKNNIndex() throws IOException {
if (randomBoolean()) {
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
} else {
createKnnIndex(
INDEX_NAME,
createKnnIndexMapping(
FIELD_NAME,
2,
KNNConstants.METHOD_HNSW,
KNNEngine.LUCENE.getName(),
SpaceType.DEFAULT.getValue(),
randomBoolean()
)
);
}
}

public void testKNNL2ScriptScore() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Expand Down Expand Up @@ -93,7 +113,7 @@ public void testKNNL1ScriptScore() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Expand Down Expand Up @@ -146,7 +166,7 @@ public void testKNNLInfScriptScore() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Expand Down Expand Up @@ -199,7 +219,7 @@ public void testKNNCosineScriptScore() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 1.0f, -1.0f };
addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1);

Expand Down Expand Up @@ -251,7 +271,7 @@ public void testKNNInvalidSourceScript() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();

/**
* Construct Search Request
Expand Down Expand Up @@ -293,7 +313,7 @@ public void testInvalidSpace() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();

/**
* Construct Search Request
Expand All @@ -316,7 +336,7 @@ public void testMissingParamsInScript() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();

/**
* Construct Search Request
Expand Down Expand Up @@ -349,7 +369,7 @@ public void testUnequalDimensions() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 1.0f, -1.0f };
addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1);

Expand All @@ -372,7 +392,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 1.0f, 1.0f };
addDocWithNumericField(INDEX_NAME, "0", "price", 10);
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);
Expand Down Expand Up @@ -636,7 +656,7 @@ public void testKNNInnerProdScriptScore() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { -2.0f, -2.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Expand Down Expand Up @@ -690,7 +710,7 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
randomCreateKNNIndex();
Float[] f1 = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ protected String createMapping(List<MappingProperty> properties) throws IOExcept
builder.field("dimension", property.getDimension());
}

if (property.getDocValues() != null) {
builder.field("doc_values", property.getDocValues());
}

if (property.getKnnMethodContext() != null) {
builder.startObject(KNNConstants.KNN_METHOD);
property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down Expand Up @@ -554,12 +558,14 @@ public void testScriptedMetricIsSupported() throws Exception {
public void testL2ScriptingWithLuceneBackedIndex() throws Exception {
List<MappingProperty> properties = new ArrayList<>();
KNNMethodContext knnMethodContext = new KNNMethodContext(
KNNEngine.NMSLIB,
KNNEngine.LUCENE,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
properties.add(
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext)
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")
.knnMethodContext(knnMethodContext)
.docValues(randomBoolean())
);

String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Expand All @@ -585,6 +591,7 @@ static class MappingProperty {
private String dimension;

private KNNMethodContext knnMethodContext;
private Boolean docValues;

MappingProperty(String name, String type) {
this.name = name;
Expand All @@ -601,6 +608,11 @@ MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) {
return this;
}

MappingProperty docValues(boolean docValues) {
this.docValues = docValues;
return this;
}

KNNMethodContext getKnnMethodContext() {
return knnMethodContext;
}
Expand All @@ -616,5 +628,9 @@ String getName() {
String getType() {
return type;
}

Boolean getDocValues() {
return docValues;
}
}
}
Loading

0 comments on commit f30c093

Please sign in to comment.