From e356f6e6b081609bf99759739e18da4d92fa36e9 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 16 Nov 2023 11:26:41 -0800 Subject: [PATCH] [tokenizer] Not returns overflow tokens by default --- .../tokenizers/HuggingFaceTokenizer.java | 87 +++++++++++++------ .../translator/FillMaskBatchTranslator.java | 2 +- .../translator/FillMaskTranslator.java | 2 +- .../tokenizers/HuggingFaceTokenizerTest.java | 6 +- 4 files changed, 67 insertions(+), 30 deletions(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java index fa601e15525..de26f1b06c5 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizer.java @@ -44,6 +44,7 @@ public final class HuggingFaceTokenizer extends NativeResource implements private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class); private boolean addSpecialTokens; + private boolean withOverflowingTokens; private TruncationStrategy truncation; private PaddingStrategy padding; private int maxLength; @@ -64,6 +65,8 @@ private HuggingFaceTokenizer(long handle, Map options) { if (options != null) { val = options.getOrDefault("addSpecialTokens", "true"); addSpecialTokens = Boolean.parseBoolean(val); + val = options.getOrDefault("withOverflowingTokens", "false"); + withOverflowingTokens = Boolean.parseBoolean(val); modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512); if (options.containsKey("truncation")) { truncation = TruncationStrategy.fromValue(options.get("truncation")); @@ -203,11 +206,12 @@ public void close() { * @param text the input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, boolean addSpecialTokens) { + public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -217,7 +221,7 @@ public Encoding encode(String text, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text) { - return encode(text, addSpecialTokens); + return encode(text, addSpecialTokens, withOverflowingTokens); } /** @@ -227,12 +231,14 @@ public Encoding encode(String text) { * @param textPair the second input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence */ - public Encoding encode(String text, String textPair, boolean addSpecialTokens) { + public Encoding encode( + String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -243,7 +249,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence */ public Encoding encode(String text, String textPair) { - return encode(text, textPair, addSpecialTokens); + return encode(text, textPair, addSpecialTokens, withOverflowingTokens); } /** @@ -252,11 +258,13 @@ public Encoding encode(String text, String textPair) { * @param inputs the input sentences * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(List inputs, boolean addSpecialTokens) { + public Encoding encode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = inputs.toArray(Utils.EMPTY_ARRAY); - return encode(array, addSpecialTokens); + return encode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -266,7 +274,7 @@ public Encoding encode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(List inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -275,11 +283,13 @@ public Encoding encode(List inputs) { * @param inputs the input sentences * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentences */ - public Encoding encode(String[] inputs, boolean addSpecialTokens) { + public Encoding encode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens); - return toEncoding(encoding); + return toEncoding(encoding, withOverflowingTokens); } /** @@ -289,7 +299,7 @@ public Encoding encode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentences */ public Encoding encode(String[] inputs) { - return encode(inputs, addSpecialTokens); + return encode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -298,11 +308,13 @@ public Encoding encode(String[] inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + List inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { String[] array = inputs.toArray(Utils.EMPTY_ARRAY); - return batchEncode(array, addSpecialTokens); + return batchEncode(array, addSpecialTokens, withOverflowingTokens); } /** @@ -312,7 +324,7 @@ public Encoding[] batchEncode(List inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(List inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -321,13 +333,15 @@ public Encoding[] batchEncode(List inputs) { * @param inputs the batch of input sentence * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input sentence in batch */ - public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) { long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -339,7 +353,7 @@ public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) { * @return the {@code Encoding} of the input sentence in batch */ public Encoding[] batchEncode(String[] inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -348,9 +362,13 @@ public Encoding[] batchEncode(String[] inputs) { * @param inputs the batch of input text pair * @param addSpecialTokens whether to encode the sequence with special tokens relative to their * model + * @param withOverflowingTokens whether to return overflowing tokens * @return the {@code Encoding} of the input text pair in batch */ - public Encoding[] batchEncode(PairList inputs, boolean addSpecialTokens) { + public Encoding[] batchEncode( + PairList inputs, + boolean addSpecialTokens, + boolean withOverflowingTokens) { String[] text = inputs.keyArray(Utils.EMPTY_ARRAY); String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY); long[] encodings = @@ -358,7 +376,7 @@ public Encoding[] batchEncode(PairList inputs, boolean addSpecia getHandle(), text, textPair, addSpecialTokens); Encoding[] ret = new Encoding[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - ret[i] = toEncoding(encodings[i]); + ret[i] = toEncoding(encodings[i], withOverflowingTokens); } return ret; } @@ -370,7 +388,7 @@ public Encoding[] batchEncode(PairList inputs, boolean addSpecia * @return the {@code Encoding} of the input text pair in batch */ public Encoding[] batchEncode(PairList inputs) { - return batchEncode(inputs, addSpecialTokens); + return batchEncode(inputs, addSpecialTokens, withOverflowingTokens); } /** @@ -503,7 +521,7 @@ private void updateTruncationAndPadding() { } } - private Encoding toEncoding(long encoding) { + private Encoding toEncoding(long encoding, boolean withOverflowingTokens) { long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding); long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding); String[] tokens = TokenizersLibrary.LIB.getTokens(encoding); @@ -511,11 +529,17 @@ private Encoding toEncoding(long encoding) { long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding); long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding); CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding); - long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); - Encoding[] overflowing = new Encoding[overflowingHandles.length]; - for (int i = 0; i < overflowingHandles.length; ++i) { - overflowing[i] = toEncoding(overflowingHandles[i]); + Encoding[] overflowing; + if (withOverflowingTokens) { + long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding); + + overflowing = new Encoding[overflowingHandles.length]; + for (int i = 0; i < overflowingHandles.length; ++i) { + overflowing[i] = toEncoding(overflowingHandles[i], true); + } + } else { + overflowing = new Encoding[0]; } TokenizersLibrary.LIB.deleteEncoding(encoding); @@ -651,6 +675,17 @@ public Builder optAddSpecialTokens(boolean addSpecialTokens) { return this; } + /** + * Sets if add special tokens. + * + * @param withOverflowingTokens true to return overflowing tokens + * @return this builder + */ + public Builder optWithOverflowingTokens(boolean withOverflowingTokens) { + options.put("withOverflowingTokens", String.valueOf(withOverflowingTokens)); + return this; + } + /** * Enables or Disables default truncation behavior for the tokenizer. * diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java index 43b120cac43..9a4ccba42b5 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java @@ -37,7 +37,7 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator { this.maskToken = maskToken; this.topK = topK; this.batchifier = batchifier; - Encoding encoding = tokenizer.encode(maskToken, false); + Encoding encoding = tokenizer.encode(maskToken, false, false); maskTokenId = encoding.getIds()[0]; } diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java index 2bc30d4bddf..8b5d6d57557 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/HuggingFaceTokenizerTest.java @@ -300,6 +300,7 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) .optTruncation(true) .optMaxLength(3) .optStride(1) @@ -322,6 +323,7 @@ public void testTruncationStride() throws IOException { HuggingFaceTokenizer.builder() .optTokenizerName("bert-base-cased") .optAddSpecialTokens(false) + .optWithOverflowingTokens(true) .optTruncation(true) .optMaxLength(8) .optStride(2) @@ -458,13 +460,13 @@ public void testBatchProcessing() throws IOException { Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode with special tokens, decode with special tokens - encodings = tokenizer.batchEncode(inputs, true); + encodings = tokenizer.batchEncode(inputs, true, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, false); Assert.assertEquals(outputs, outputsWithSpecialTokens); // encode without special tokens, decode without special tokens - encodings = tokenizer.batchEncode(inputs, false); + encodings = tokenizer.batchEncode(inputs, false, false); batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new); outputs = tokenizer.batchDecode(batchIds, true); Assert.assertEquals(outputs, outputsWithoutSpecialTokens);