Skip to content

Commit

Permalink
Fix tokenization for multilingual Helsinki-NLP models (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Apr 6, 2023
1 parent 682f7a1 commit 12163dd
Showing 1 changed file with 50 additions and 5 deletions.
55 changes: 50 additions & 5 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -1384,10 +1384,10 @@ class WhitespaceSplit extends PreTokenizer {

class PreTrainedTokenizer extends Callable {
/**
* Create a new PreTrainedTokenizer instance.
* @param {Object} tokenizerJSON - The JSON of the tokenizer.
* @param {Object} tokenizerConfig - The config of the tokenizer.
*/
* Create a new PreTrainedTokenizer instance.
* @param {Object} tokenizerJSON - The JSON of the tokenizer.
* @param {Object} tokenizerConfig - The config of the tokenizer.
*/
constructor(tokenizerJSON, tokenizerConfig) {
super();

Expand Down Expand Up @@ -2279,7 +2279,52 @@ class WhisperTokenizer extends PreTrainedTokenizer {
}
class CodeGenTokenizer extends PreTrainedTokenizer { }
class CLIPTokenizer extends PreTrainedTokenizer { }
class MarianTokenizer extends PreTrainedTokenizer { }
class MarianTokenizer extends PreTrainedTokenizer {
/**
* Create a new MarianTokenizer instance.
* @param {Object} tokenizerJSON - The JSON of the tokenizer.
* @param {Object} tokenizerConfig - The config of the tokenizer.
*/
constructor(tokenizerJSON, tokenizerConfig) {
super(tokenizerJSON, tokenizerConfig);

this.languageRegex = /^(>>\w+<<)\s*/g;

this.supported_language_codes = this.model.vocab.filter(
x => this.languageRegex.test(x)
);
}

/**
* Encodes a single text. Overriding this method is necessary since the language codes
* must be removed before encoding with sentencepiece model.
* @see https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/models/marian/tokenization_marian.py#L204-L213
*
* @param {string|null} text - The text to encode.
* @returns {Array} The encoded tokens.
*/
_encode_text(text) {
if (text === null) return null;

// Check if text starts with language code:
let [matchInfo, ...remainder] = text.trim().split(this.languageRegex);

if (remainder.length === 0) {
// No language code, encode normally
return super._encode_text(matchInfo);

} else if (remainder.length === 2) {
// Text starts with language code, so we do not encode it with sentencepiece.
let [language, text] = remainder;

if (!this.supported_language_codes.includes(language)) {
console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`)
}
return [language, ...super._encode_text(text)]
}
}

}

/**
* A trie structure to efficiently store and search for strings.
Expand Down

0 comments on commit 12163dd

Please sign in to comment.