From 54fb4ea0f9438c00f25d317c98e157bf68a786ad Mon Sep 17 00:00:00 2001
From: Lior Knaany
Date: Sat, 7 Apr 2018 15:19:50 +0300
Subject: [PATCH] added an option to receive a base64 encoded vector as an input

---
 .../java/com/liorkn/elasticsearch/Util.java   | 56 ++++++++++++
 .../script/VectorScoreScript.java             | 48 +++++++++----------
 2 files changed, 79 insertions(+), 25 deletions(-)
 create mode 100644 src/main/java/com/liorkn/elasticsearch/Util.java

diff --git a/src/main/java/com/liorkn/elasticsearch/Util.java b/src/main/java/com/liorkn/elasticsearch/Util.java
new file mode 100644
index 0000000..de81af8
--- /dev/null
+++ b/src/main/java/com/liorkn/elasticsearch/Util.java
@@ -0,0 +1,56 @@
+package com.liorkn.elasticsearch;
+
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.util.Base64;
+
+/**
+ * Helpers for converting between {@code double[]} vectors and their Base64 text form.
+ *
+ * <p>The binary layout is the one produced by {@link ByteBuffer#putDouble(double)} on a
+ * freshly allocated buffer: 8 bytes per element, big-endian, no header.
+ *
+ * Created by Lior Knaany on 4/7/18.
+ */
+public final class Util {
+
+    private static final int DOUBLE_SIZE = Double.BYTES;
+
+    private Util() {
+        // static helpers only
+    }
+
+    /**
+     * Decodes a Base64 string into the vector of doubles it represents.
+     *
+     * @param base64Str Base64 (basic alphabet) encoding of consecutive 8-byte doubles
+     * @return the decoded values; trailing bytes that do not fill a whole double are ignored
+     * @throws IllegalArgumentException if {@code base64Str} is not valid Base64
+     */
+    public static double[] convertBase64ToArray(String base64Str) {
+        // decode(String) maps characters as ISO-8859-1, so the result no longer depends on
+        // the platform default charset the way getBytes() did.
+        final byte[] decoded = Base64.getDecoder().decode(base64Str);
+        final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decoded).asDoubleBuffer();
+
+        final double[] dims = new double[doubleBuffer.remaining()];
+        doubleBuffer.get(dims);
+        return dims;
+    }
+
+    /**
+     * Encodes a vector of doubles as Base64 text.
+     *
+     * @param array values to encode
+     * @return Base64 of the values written back to back, 8 bytes each
+     */
+    public static String convertArrayToBase64(double[] array) {
+        final ByteBuffer bb = ByteBuffer.allocate(DOUBLE_SIZE * array.length);
+        for (final double value : array) {
+            bb.putDouble(value);
+        }
+        // The buffer was allocated at exactly the needed size, so its backing array holds
+        // precisely the written bytes.
+        return Base64.getEncoder().encodeToString(bb.array());
+    }
+}
diff --git a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
index 73adb0b..7e60099 100755
--- a/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
+++ b/src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java
@@ -14,8 +14,8 @@ package com.liorkn.elasticsearch.script; +import 
com.liorkn.elasticsearch.Util; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.search.Scorer; import org.apache.lucene.store.ByteArrayDataInput; import org.elasticsearch.common.Nullable; import org.elasticsearch.script.ExecutableScript; @@ -32,14 +32,13 @@ */ public final class VectorScoreScript implements LeafSearchScript, ExecutableScript { - // private final static ESLogger logger = ESLoggerFactory.getLogger(VectorScoreScript.class.getName()); - public final static String SCRIPT_NAME = "binary_vector_score"; + public static final String SCRIPT_NAME = "binary_vector_score"; + + private static final int DOUBLE_SIZE = 8; // the field containing the vectors to be scored against public final String field; - private static final int DOUBLE_SIZE = 8; - private int docId; private BinaryDocValues binaryEmbeddingReader; @@ -49,24 +48,15 @@ public final class VectorScoreScript implements LeafSearchScript, ExecutableScri private final boolean cosine; @Override - public void setScorer(Scorer scorer) { - } - public void setSource(Map source) { - } - public float runAsFloat() { - return ((Number)this.run()).floatValue(); - } - public long runAsLong() { return ((Number)this.run()).longValue(); } + @Override public double runAsDouble() { return ((Number)this.run()).doubleValue(); } - public Object unwrap(Object value) { - return value; - } - + @Override + public void setNextVar(String name, Object value) {} @Override public void setDocument(int docId) { this.docId = docId; @@ -127,17 +117,27 @@ public VectorScoreScript(Map params) { this.field = field.toString(); // get query inputVector - convert to primitive - final ArrayList tmp = (ArrayList) params.get("vector"); - this.inputVector = new double[tmp.size()]; - for (int i = 0; i < inputVector.length; i++) { - inputVector[i] = tmp.get(i); + + final Object vector = params.get("vector"); + if(vector != null) { + final ArrayList tmp = (ArrayList) vector; + inputVector = new double[tmp.size()]; + for (int 
i = 0; i < inputVector.length; i++) { + inputVector[i] = tmp.get(i); + } + } else { + final Object encodedVector = params.get("encoded_vector"); + if(encodedVector == null) { + throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter"); + } + inputVector = Util.convertBase64ToArray((String) encodedVector); } if(cosine) { // calc magnitude double queryVectorNorm = 0.0; // compute query inputVector norm once - for (double v : this.inputVector) { + for (double v : inputVector) { queryVectorNorm += v * v; } magnitude = Math.sqrt(queryVectorNorm); @@ -146,9 +146,7 @@ public VectorScoreScript(Map params) { } } - @Override - public void setNextVar(String name, Object value) { - } + /** * Called for each document