Skip to content

Commit

Permalink
Revert 'Support script score when doc value is disabled' (#1662)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored Apr 30, 2024
1 parent e608d2d commit bd2f403
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 186 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support filter and nested field in faiss engine radial search [#1652](https://github.com/opensearch-project/k-NN/pull/1652)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613)
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

package org.opensearch.knn.index;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -40,29 +39,10 @@ public long ramBytesUsed() {
@Override
public ScriptDocValues<float[]> getScriptValues() {
try {
FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName);
if (fieldInfo == null) {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
}

DocIdSetIterator values;
if (fieldInfo.hasVectorValues()) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
values = reader.getFloatVectorValues(fieldName);
break;
case BYTE:
values = reader.getByteVectorValues(fieldName);
break;
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else {
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
BinaryDocValues values = DocValues.getBinary(reader, fieldName);
return new KNNVectorScriptDocValues(values, fieldName, vectorDataType);
} catch (IOException e) {
throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e);
throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
}
}

Expand Down
108 changes: 11 additions & 97 deletions src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@
package org.opensearch.knn.index;

import java.io.IOException;
import java.util.Objects;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.ExceptionsHelper;
import org.opensearch.index.fielddata.ScriptDocValues;

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {
import java.io.IOException;

@RequiredArgsConstructor
public final class KNNVectorScriptDocValues extends ScriptDocValues<float[]> {

private final DocIdSetIterator vectorValues;
private final BinaryDocValues binaryDocValues;
private final String fieldName;
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;

@Override
public void setNextDocId(int docId) throws IOException {
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
if (binaryDocValues.advanceExact(docId)) {
docExists = true;
return;
}
docExists = false;
}

public float[] getValue() {
Expand All @@ -43,14 +44,12 @@ public float[] getValue() {
throw new IllegalStateException(errorMessage);
}
try {
return doGetValue();
return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue());
} catch (IOException e) {
throw ExceptionsHelper.convertToOpenSearchException(e);
}
}

protected abstract float[] doGetValue() throws IOException;

@Override
public int size() {
return docExists ? 1 : 0;
Expand All @@ -60,89 +59,4 @@ public int size() {
public float[] get(int i) {
throw new UnsupportedOperationException("knn vector does not support this operation");
}

/**
* Creates a KNNVectorScriptDocValues object based on the provided parameters.
*
* @param values The DocIdSetIterator representing the vector values.
* @param fieldName The name of the field.
* @param vectorDataType The data type of the vector.
* @return A KNNVectorScriptDocValues object based on the type of the values.
* @throws IllegalArgumentException If the type of values is unsupported.
*/
public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) {
Objects.requireNonNull(values, "values must not be null");
if (values instanceof ByteVectorValues) {
return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof FloatVectorValues) {
return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType);
} else if (values instanceof BinaryDocValues) {
return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType);
} else {
throw new IllegalArgumentException("Unsupported values type: " + values.getClass());
}
}

private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues {
private final ByteVectorValues values;

KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
byte[] bytes = values.vectorValue();
float[] value = new float[bytes.length];
for (int i = 0; i < bytes.length; i++) {
value[i] = (float) bytes[i];
}
return value;
}
}

private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues {
private final FloatVectorValues values;

KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return values.vectorValue();
}
}

private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues {
private final BinaryDocValues values;

KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) {
super(values, field, type);
this.values = values;
}

@Override
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}

/**
* Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type.
*
* @param fieldName The name of the field.
* @param type The data type of the vector.
* @return An empty KNNVectorScriptDocValues object.
*/
public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) {
return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) {
@Override
protected float[] doGetValue() throws IOException {
throw new UnsupportedOperationException("empty values");
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,7 @@

package org.opensearch.knn.index;

import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand Down Expand Up @@ -41,39 +33,26 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase {
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
Class<? extends DocIdSetIterator> valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class);
createKNNVectorDocument(directory, valuesClass);
createKNNVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReader leafReader = reader.getContext().leaves().get(0).reader();
DocIdSetIterator vectorValues;
if (BinaryDocValues.class.equals(valuesClass)) {
vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME);
} else if (ByteVectorValues.class.equals(valuesClass)) {
vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME);
} else {
vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME);
}

scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
}

private void createKNNVectorDocument(Directory directory, Class<? extends DocIdSetIterator> valuesClass) throws IOException {
private void createKNNVectorDocument(Directory directory) throws IOException {
IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random()));
IndexWriter writer = new IndexWriter(directory, conf);
Document knnDocument = new Document();
Field field;
if (BinaryDocValues.class.equals(valuesClass)) {
field = new BinaryDocValuesField(
knnDocument.add(
new BinaryDocValuesField(
MOCK_INDEX_FIELD_NAME,
new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue()
);
} else if (ByteVectorValues.class.equals(valuesClass)) {
field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA);
} else {
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA);
}

knnDocument.add(field);
)
);
writer.addDocument(knnDocument);
writer.commit();
writer.close();
Expand Down Expand Up @@ -105,18 +84,4 @@ public void testSize() throws IOException {
public void testGet() throws IOException {
expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
}

public void testUnsupportedValues() throws IOException {
expectThrows(
IllegalArgumentException.class,
() -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT)
);
}

public void testEmptyValues() throws IOException {
KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT);
assertEquals(0, values.size());
scriptDocValues.setNextDocId(0);
assertEquals(0, values.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
createKNNFloatVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return KNNVectorScriptDocValues.create(
return new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME,
VectorDataType.FLOAT
Expand All @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() {
createKNNByteVectorDocument(directory);
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
return KNNVectorScriptDocValues.create(
return new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME),
VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME,
VectorDataType.BYTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx
if (scriptDocValues == null) {
reader = DirectoryReader.open(directory);
LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0);
scriptDocValues = KNNVectorScriptDocValues.create(
scriptDocValues = new KNNVectorScriptDocValues(
leafReaderContext.reader().getBinaryDocValues(fieldName),
fieldName,
VectorDataType.FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,16 +612,7 @@ private List<String> createMappers(int dimensions) throws Exception {
dimensions,
KNNConstants.METHOD_HNSW,
KNNEngine.LUCENE.getName(),
SpaceType.DEFAULT.getValue(),
true
),
createKnnIndexMapping(
FIELD_NAME,
dimensions,
KNNConstants.METHOD_HNSW,
KNNEngine.LUCENE.getName(),
SpaceType.DEFAULT.getValue(),
false
SpaceType.DEFAULT.getValue()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,7 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception {
new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())
);
properties.add(
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2")
.knnMethodContext(knnMethodContext)
.docValues(randomBoolean())
new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext)
);

String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
Expand Down

0 comments on commit bd2f403

Please sign in to comment.