Skip to content

Commit

Permalink
added an option to receive a base64 encoded vector as an input
Browse files Browse the repository at this point in the history
  • Loading branch information
lior-k committed Apr 7, 2018
1 parent 5bfd46f commit 54fb4ea
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 25 deletions.
31 changes: 31 additions & 0 deletions src/main/java/com/liorkn/elasticsearch/Util.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.liorkn.elasticsearch;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.Base64;

/**
 * Helpers for converting between {@code double[]} vectors and their Base64 text form.
 *
 * <p>Each double is written as 8 bytes in {@link ByteBuffer}'s default (big-endian) byte
 * order, and the resulting byte sequence is encoded with the basic Base64 alphabet.
 *
 * Created by Lior Knaany on 4/7/18.
 */
public class Util {

    /**
     * Decodes a Base64 string into the array of doubles it represents.
     *
     * @param base64Str Base64 text holding a whole number of 8-byte doubles; must not be null
     * @return the decoded values, in the order they were encoded (empty for an empty string)
     * @throws IllegalArgumentException if the text is not valid Base64, or if the decoded
     *         payload is not a multiple of 8 bytes long
     */
    public static final double[] convertBase64ToArray(String base64Str) {
        // decode(String) maps characters to bytes with a fixed charset, so the result does
        // not depend on the platform default encoding.
        final byte[] decoded = Base64.getDecoder().decode(base64Str);

        // A partial trailing double means the payload is truncated or corrupted; reject it
        // instead of silently returning a shorter vector.
        if (decoded.length % Double.BYTES != 0) {
            throw new IllegalArgumentException(
                    "Decoded vector length must be a multiple of " + Double.BYTES
                            + " bytes but was " + decoded.length);
        }

        final double[] dims = new double[decoded.length / Double.BYTES];
        ByteBuffer.wrap(decoded).asDoubleBuffer().get(dims);
        return dims;
    }

    /**
     * Encodes an array of doubles as a Base64 string.
     *
     * @param array the values to encode; must not be null
     * @return the Base64 text of the values' 8-byte representations (empty for an empty array)
     */
    public static final String convertArrayToBase64(double[] array) {
        final ByteBuffer bb = ByteBuffer.allocate(Double.BYTES * array.length);
        for (final double value : array) {
            bb.putDouble(value);
        }
        // The heap buffer's backing array is exactly the bytes written above.
        return Base64.getEncoder().encodeToString(bb.array());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

package com.liorkn.elasticsearch.script;

import com.liorkn.elasticsearch.Util;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.store.ByteArrayDataInput;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.script.ExecutableScript;
Expand All @@ -32,14 +32,13 @@
*/
public final class VectorScoreScript implements LeafSearchScript, ExecutableScript {

// private final static ESLogger logger = ESLoggerFactory.getLogger(VectorScoreScript.class.getName());
public final static String SCRIPT_NAME = "binary_vector_score";
public static final String SCRIPT_NAME = "binary_vector_score";

private static final int DOUBLE_SIZE = 8;

// the field containing the vectors to be scored against
public final String field;

private static final int DOUBLE_SIZE = 8;

private int docId;
private BinaryDocValues binaryEmbeddingReader;

Expand All @@ -49,24 +48,15 @@ public final class VectorScoreScript implements LeafSearchScript, ExecutableScri
private final boolean cosine;

@Override
public void setScorer(Scorer scorer) {
}
public void setSource(Map<String, Object> source) {
}
public float runAsFloat() {
return ((Number)this.run()).floatValue();
}

public long runAsLong() {
return ((Number)this.run()).longValue();
}
@Override
public double runAsDouble() {
return ((Number)this.run()).doubleValue();
}
public Object unwrap(Object value) {
return value;
}

@Override
public void setNextVar(String name, Object value) {}
@Override
public void setDocument(int docId) {
this.docId = docId;
Expand Down Expand Up @@ -127,17 +117,27 @@ public VectorScoreScript(Map<String, Object> params) {
this.field = field.toString();

// get query inputVector - convert to primitive
final ArrayList<Double> tmp = (ArrayList<Double>) params.get("vector");
this.inputVector = new double[tmp.size()];
for (int i = 0; i < inputVector.length; i++) {
inputVector[i] = tmp.get(i);

final Object vector = params.get("vector");
if(vector != null) {
final ArrayList<Double> tmp = (ArrayList<Double>) vector;
inputVector = new double[tmp.size()];
for (int i = 0; i < inputVector.length; i++) {
inputVector[i] = tmp.get(i);
}
} else {
final Object encodedVector = params.get("encoded_vector");
if(encodedVector == null) {
throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter");
}
inputVector = Util.convertBase64ToArray((String) encodedVector);
}

if(cosine) {
// calc magnitude
double queryVectorNorm = 0.0;
// compute query inputVector norm once
for (double v : this.inputVector) {
for (double v : inputVector) {
queryVectorNorm += v * v;
}
magnitude = Math.sqrt(queryVectorNorm);
Expand All @@ -146,9 +146,7 @@ public VectorScoreScript(Map<String, Object> params) {
}
}

@Override
public void setNextVar(String name, Object value) {
}


/**
* Called for each document
Expand Down

0 comments on commit 54fb4ea

Please sign in to comment.