[tokenizers] Return score for QA inference
xyang16 committed Dec 8, 2024
1 parent e4d184d commit 879626c
Showing 8 changed files with 115 additions and 6 deletions.
@@ -48,7 +48,7 @@ private BertQaInference() {}

public static void main(String[] args) throws IOException, TranslateException, ModelException {
String answer = BertQaInference.predict();
logger.info("Answer: {}", answer);
logger.info("Output: {}", answer);
}

public static String predict() throws IOException, TranslateException, ModelException {
@@ -69,6 +69,7 @@ public static String predict() throws IOException, TranslateException, ModelException {
"djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2")
.optEngine("PyTorch")
.optTranslatorFactory(new QuestionAnsweringTranslatorFactory())
.optArgument("detail", true)
.optProgress(new ProgressBar())
.build();
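
With "detail" enabled, predict() returns a JSON string instead of the bare answer. A minimal sketch of consuming that output, assuming the field names added in this commit (score, start, end, answer); the class name and the literal values are illustrative, not part of the commit:

import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;

public final class DetailOutputExample {
    public static void main(String[] args) {
        // Payload shaped like the translator's detail output; the values here are made up
        String result = "{\"score\":0.93,\"start\":28,\"end\":32,\"answer\":\"december 2004\"}";
        JsonObject json = JsonUtils.GSON.fromJson(result, JsonObject.class);
        System.out.printf(
                "%s (score=%.2f)%n",
                json.get("answer").getAsString(), json.get("score").getAsFloat());
    }
}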

@@ -15,6 +15,9 @@
import ai.djl.ModelException;
import ai.djl.testing.TestRequirements;
import ai.djl.translate.TranslateException;
import ai.djl.util.JsonUtils;

import com.google.gson.JsonObject;

import org.testng.Assert;
import org.testng.annotations.Test;
@@ -28,6 +31,8 @@ public void testBertQa() throws ModelException, TranslateException, IOException
TestRequirements.linux();

String result = BertQaInference.predict();
Assert.assertEquals(result, "december 2004");
JsonObject json = JsonUtils.GSON.fromJson(result, JsonObject.class);
String answer = json.get("answer").getAsString();
Assert.assertEquals(answer, "december 2004");
}
}
27 changes: 26 additions & 1 deletion extensions/tokenizers/rust/src/lib.rs
@@ -39,8 +39,8 @@ use tk::models::bpe::BPE;
use tk::tokenizer::{EncodeInput, Encoding};
use tk::utils::padding::{PaddingParams, PaddingStrategy};
use tk::utils::truncation::{TruncationParams, TruncationStrategy};
use tk::Tokenizer;
use tk::Offsets;
use tk::Tokenizer;

#[cfg(not(target_os = "android"))]
use tk::FromPretrainedParameters;
@@ -407,6 +407,31 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getSequenceIds<
'local,
>(
env: JNIEnv<'local>,
_: JObject,
handle: jlong,
) -> JLongArray<'local> {
let encoding = cast_handle::<Encoding>(handle);
let sequence_ids = encoding.get_sequence_ids();
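// get_sequence_ids() yields one Option<usize> per token; None (special tokens) maps to -1 below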
let len = sequence_ids.len() as jsize;
let mut long_ids: Vec<jlong> = Vec::new();
for i in sequence_ids {
if let Some(sequence_id) = i {
long_ids.push(sequence_id as jlong)
} else {
long_ids.push(-1)
}
}

let array = env.new_long_array(len).unwrap();
env.set_long_array_region(&array, 0, &long_ids).unwrap();
array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokens<
'local,
@@ -25,6 +25,7 @@ public class Encoding {
private long[] typeIds;
private String[] tokens;
private long[] wordIds;
private long[] sequenceIds;
private long[] attentionMask;
private long[] specialTokenMask;
private CharSpan[] charTokenSpans;
@@ -36,6 +37,7 @@ protected Encoding(
long[] typeIds,
String[] tokens,
long[] wordIds,
long[] sequenceIds,
long[] attentionMask,
long[] specialTokenMask,
CharSpan[] charTokenSpans,
@@ -45,6 +47,7 @@ protected Encoding(
this.typeIds = typeIds;
this.tokens = tokens;
this.wordIds = wordIds;
this.sequenceIds = sequenceIds;
this.attentionMask = attentionMask;
this.specialTokenMask = specialTokenMask;
this.charTokenSpans = charTokenSpans;
@@ -109,6 +112,15 @@ public long[] getWordIds() {
return wordIds;
}

/**
* Returns the sequence ids.
*
* @return the sequence ids
*/
public long[] getSequenceIds() {
return sequenceIds;
}

/**
* Returns the attention masks.
*
@@ -623,6 +623,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
long[] sequenceIds = TokenizersLibrary.LIB.getSequenceIds(encoding);
long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
@@ -646,6 +647,7 @@ private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
typeIds,
tokens,
wordIds,
sequenceIds,
attentionMask,
specialTokenMask,
charSpans,
@@ -50,6 +50,8 @@ public native long[] batchEncodePair(

public native long[] getWordIds(long encoding);

public native long[] getSequenceIds(long encoding);

public native String[] getTokens(long encoding);

public native long[] getAttentionMask(long encoding);
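
Together, the Rust function, the Encoding field, and the native binding above expose sequence ids end to end. A minimal usage sketch, assuming bert-base-cased as the model id; the class name and inputs are illustrative, and the expected array mirrors the pair-encoding assertion in the tokenizer test below:

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;

import java.io.IOException;
import java.util.Arrays;

public final class SequenceIdsExample {
    public static void main(String[] args) throws IOException {
        try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
            Encoding encoding = tokenizer.encode("How are you?", "I am fine.");
            // -1 for special tokens ([CLS]/[SEP]), 0 for the first sequence, 1 for the second
            long[] sequenceIds = encoding.getSequenceIds();
            System.out.println(Arrays.toString(sequenceIds));
        }
    }
}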
@@ -23,25 +23,32 @@
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.PairList;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** The translator for Huggingface question answering model. */
public class QuestionAnsweringTranslator implements Translator<QAInput, String> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private boolean detail;

QuestionAnsweringTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
boolean includeTokenTypes,
Batchifier batchifier,
boolean detail) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
this.detail = detail;
}

/** {@inheritDoc} */
@@ -102,6 +109,27 @@ private String decode(NDList list, Encoding encoding) {
startLogits = startLogits.duplicate();
endLogits = endLogits.duplicate();
}
if (detail) {
// exclude question tokens (sequence id 0) from the answer candidates
long[] sequenceIds = encoding.getSequenceIds();
List<Integer> undesired = new ArrayList<>();
for (int i = 0; i < sequenceIds.length; ++i) {
if (sequenceIds[i] == 0) {
undesired.add(i);
}
}
int[] idx = undesired.stream().mapToInt(Integer::intValue).toArray();
NDIndex ndIndex = new NDIndex("{}", list.getManager().create(idx));
startLogits.set(ndIndex, -100000f);
endLogits.set(ndIndex, -100000f);

// normalize
startLogits = startLogits.sub(startLogits.max()).exp();
startLogits = startLogits.div(startLogits.sum());
endLogits = endLogits.sub(endLogits.max()).exp();
endLogits = endLogits.div(endLogits.sum());
}

// exclude <CLS>, TODO: exclude impossible ids properly and handle max answer length
startLogits.set(new NDIndex(0), -100000);
endLogits.set(new NDIndex(0), -100000);
@@ -111,12 +139,26 @@
int tmp = startIdx;
startIdx = endIdx;
endIdx = tmp;
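// swap the logits too so the probability lookup below stays aligned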
NDArray tmpArray = startLogits;
startLogits = endLogits;
endLogits = tmpArray;
}
long[] indices = encoding.getIds();
int len = endIdx - startIdx + 1;
long[] ids = new long[len];
System.arraycopy(indices, startIdx, ids, 0, len);
return tokenizer.decode(ids).trim();
String answer = tokenizer.decode(ids).trim();
if (detail) {
float score = startLogits.getFloat(startIdx) * endLogits.getFloat(endIdx);

Map<String, Object> dict = new ConcurrentHashMap<>();
dict.put("score", score);
dict.put("start", startIdx);
dict.put("end", endIdx);
dict.put("answer", answer);
return JsonUtils.toJson(dict);
}
return answer;
}

/**
@@ -149,6 +191,7 @@ public static final class Builder {
private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;
private boolean detail;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
@@ -176,6 +219,17 @@ public Builder optBatchifier(Batchifier batchifier) {
return this;
}

/**
* Sets whether to return detailed output for the {@link Translator}.
*
* @param detail true to output detail
* @return this builder
*/
public Builder optDetail(boolean detail) {
this.detail = detail;
return this;
}

/**
* Configures the builder with the model arguments.
*
@@ -184,6 +238,7 @@ public Builder optBatchifier(Batchifier batchifier) {
public void configure(Map<String, ?> arguments) {
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optDetail(ArgumentsUtil.booleanValue(arguments, "detail"));
optBatchifier(Batchifier.fromString(batchifierStr));
}

@@ -194,7 +249,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public QuestionAnsweringTranslator build() throws IOException {
return new QuestionAnsweringTranslator(tokenizer, includeTokenTypes, batchifier);
return new QuestionAnsweringTranslator(
tokenizer, includeTokenTypes, batchifier, detail);
}
}
}
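
For reference, the normalization added in decode() is a numerically stable softmax, so with detail enabled the reported score is the product of the start and end probabilities:

    score = softmax(startLogits)[startIdx] * softmax(endLogits)[endIdx]
    softmax(l)_i = exp(l_i - max(l)) / sum_j exp(l_j - max(l))

Question-sequence tokens and [CLS] are masked to -100000 before the argmax, so they cannot be chosen as span boundaries.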
@@ -67,13 +67,15 @@ public void testTokenizer() throws IOException {
long[] ids = {101, 8667, 117, 194, 112, 1155, 106, 1731, 1132, 1128, 100, 136, 102};
long[] typeIds = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
long[] wordIds = {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, -1};
long[] sequenceIds = {-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1};
long[] attentionMask = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
long[] specialTokenMask = {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1};

Assert.assertEquals(expected, encoding.getTokens());
Assert.assertEquals(ids, encoding.getIds());
Assert.assertEquals(typeIds, encoding.getTypeIds());
Assert.assertEquals(wordIds, encoding.getWordIds());
Assert.assertEquals(sequenceIds, encoding.getSequenceIds());
Assert.assertEquals(attentionMask, encoding.getAttentionMask());
Assert.assertEquals(specialTokenMask, encoding.getSpecialTokenMask());

@@ -104,6 +106,10 @@ public void testTokenizer() throws IOException {
Assert.assertEquals(charSpansExpected[i].getEnd(), charSpansResult[i].getEnd());
}

encoding = tokenizer.encode(inputs[0], inputs[1]);
sequenceIds = new long[] {-1, 0, 0, 0, 0, 0, 0, -1, 1, 1, 1, 1, 1, -1};
Assert.assertEquals(encoding.getSequenceIds(), sequenceIds);

Assert.assertThrows(() -> tokenizer.encode((String) null));
Assert.assertThrows(() -> tokenizer.encode(new String[] {null}));
Assert.assertThrows(() -> tokenizer.encode(null, null));