Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizer] Not returns overflow tokens by default #2857

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public final class HuggingFaceTokenizer extends NativeResource<Long> implements
private static final Logger logger = LoggerFactory.getLogger(HuggingFaceTokenizer.class);

private boolean addSpecialTokens;
private boolean withOverflowingTokens;
private TruncationStrategy truncation;
private PaddingStrategy padding;
private int maxLength;
Expand All @@ -64,6 +65,8 @@ private HuggingFaceTokenizer(long handle, Map<String, String> options) {
if (options != null) {
val = options.getOrDefault("addSpecialTokens", "true");
addSpecialTokens = Boolean.parseBoolean(val);
val = options.getOrDefault("withOverflowingTokens", "false");
withOverflowingTokens = Boolean.parseBoolean(val);
modelMaxLength = ArgumentsUtil.intValue(options, "modelMaxLength", 512);
if (options.containsKey("truncation")) {
truncation = TruncationStrategy.fromValue(options.get("truncation"));
Expand Down Expand Up @@ -203,11 +206,12 @@ public void close() {
* @param text the input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, boolean addSpecialTokens) {
public Encoding encode(String text, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -217,7 +221,7 @@ public Encoding encode(String text, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text) {
return encode(text, addSpecialTokens);
return encode(text, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -227,12 +231,14 @@ public Encoding encode(String text) {
* @param textPair the second input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
public Encoding encode(
String text, String textPair, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding =
TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -243,7 +249,7 @@ public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair) {
return encode(text, textPair, addSpecialTokens);
return encode(text, textPair, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -252,11 +258,13 @@ public Encoding encode(String text, String textPair) {
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
public Encoding encode(
List<String> inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
return encode(array, addSpecialTokens);
return encode(array, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -266,7 +274,7 @@ public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(List<String> inputs) {
return encode(inputs, addSpecialTokens);
return encode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -275,11 +283,13 @@ public Encoding encode(List<String> inputs) {
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(String[] inputs, boolean addSpecialTokens) {
public Encoding encode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
long encoding = TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens);
return toEncoding(encoding);
return toEncoding(encoding, withOverflowingTokens);
}

/**
Expand All @@ -289,7 +299,7 @@ public Encoding encode(String[] inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(String[] inputs) {
return encode(inputs, addSpecialTokens);
return encode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -298,11 +308,13 @@ public Encoding encode(String[] inputs) {
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
List<String> inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
String[] array = inputs.toArray(Utils.EMPTY_ARRAY);
return batchEncode(array, addSpecialTokens);
return batchEncode(array, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -312,7 +324,7 @@ public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(List<String> inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -321,13 +333,15 @@ public Encoding[] batchEncode(List<String> inputs) {
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
String[] inputs, boolean addSpecialTokens, boolean withOverflowingTokens) {
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
ret[i] = toEncoding(encodings[i]);
ret[i] = toEncoding(encodings[i], withOverflowingTokens);
}
return ret;
}
Expand All @@ -339,7 +353,7 @@ public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand All @@ -348,17 +362,21 @@ public Encoding[] batchEncode(String[] inputs) {
* @param inputs the batch of input text pair
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @param withOverflowingTokens whether to return overflowing tokens
* @return the {@code Encoding} of the input text pair in batch
*/
public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecialTokens) {
public Encoding[] batchEncode(
PairList<String, String> inputs,
boolean addSpecialTokens,
boolean withOverflowingTokens) {
String[] text = inputs.keyArray(Utils.EMPTY_ARRAY);
String[] textPair = inputs.valueArray(Utils.EMPTY_ARRAY);
long[] encodings =
TokenizersLibrary.LIB.batchEncodePair(
getHandle(), text, textPair, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
ret[i] = toEncoding(encodings[i]);
ret[i] = toEncoding(encodings[i], withOverflowingTokens);
}
return ret;
}
Expand All @@ -370,7 +388,7 @@ public Encoding[] batchEncode(PairList<String, String> inputs, boolean addSpecia
* @return the {@code Encoding} of the input text pair in batch
*/
public Encoding[] batchEncode(PairList<String, String> inputs) {
return batchEncode(inputs, addSpecialTokens);
return batchEncode(inputs, addSpecialTokens, withOverflowingTokens);
}

/**
Expand Down Expand Up @@ -503,19 +521,25 @@ private void updateTruncationAndPadding() {
}
}

private Encoding toEncoding(long encoding) {
private Encoding toEncoding(long encoding, boolean withOverflowingTokens) {
long[] ids = TokenizersLibrary.LIB.getTokenIds(encoding);
long[] typeIds = TokenizersLibrary.LIB.getTypeIds(encoding);
String[] tokens = TokenizersLibrary.LIB.getTokens(encoding);
long[] wordIds = TokenizersLibrary.LIB.getWordIds(encoding);
long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

Encoding[] overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i]);
Encoding[] overflowing;
if (withOverflowingTokens) {
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i], true);
}
} else {
overflowing = new Encoding[0];
}

TokenizersLibrary.LIB.deleteEncoding(encoding);
Expand Down Expand Up @@ -651,6 +675,17 @@ public Builder optAddSpecialTokens(boolean addSpecialTokens) {
return this;
}

/**
* Sets if add special tokens.
*
* @param withOverflowingTokens true to return overflowing tokens
* @return this builder
*/
public Builder optWithOverflowingTokens(boolean withOverflowingTokens) {
options.put("withOverflowingTokens", String.valueOf(withOverflowingTokens));
return this;
}

/**
* Enables or Disables default truncation behavior for the tokenizer.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator<String[], C
this.maskToken = maskToken;
this.topK = topK;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false);
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
this.maskToken = maskToken;
this.topK = topK;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false);
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ public void testTruncationStride() throws IOException {
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optAddSpecialTokens(false)
.optWithOverflowingTokens(true)
.optTruncation(true)
.optMaxLength(3)
.optStride(1)
Expand All @@ -322,6 +323,7 @@ public void testTruncationStride() throws IOException {
HuggingFaceTokenizer.builder()
.optTokenizerName("bert-base-cased")
.optAddSpecialTokens(false)
.optWithOverflowingTokens(true)
.optTruncation(true)
.optMaxLength(8)
.optStride(2)
Expand Down Expand Up @@ -458,13 +460,13 @@ public void testBatchProcessing() throws IOException {
Assert.assertEquals(outputs, outputsWithSpecialTokens);

// encode with special tokens, decode with special tokens
encodings = tokenizer.batchEncode(inputs, true);
encodings = tokenizer.batchEncode(inputs, true, false);
batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new);
outputs = tokenizer.batchDecode(batchIds, false);
Assert.assertEquals(outputs, outputsWithSpecialTokens);

// encode without special tokens, decode without special tokens
encodings = tokenizer.batchEncode(inputs, false);
encodings = tokenizer.batchEncode(inputs, false, false);
batchIds = Arrays.stream(encodings).map(Encoding::getIds).toArray(long[][]::new);
outputs = tokenizer.batchDecode(batchIds, true);
Assert.assertEquals(outputs, outputsWithoutSpecialTokens);
Expand Down
Loading