/**
 * Tokenizer for Marian models.
 *
 * Input text may start with a language code of the form `>>xx<<`. Such a prefix is
 * split off and passed through as its own token; only the remaining text is handed
 * to the parent class for encoding.
 */
class MarianTokenizer extends PreTrainedTokenizer {
    /**
     * Create a new MarianTokenizer instance.
     * @param {Object} tokenizerJSON - The JSON of the tokenizer.
     * @param {Object} tokenizerConfig - The config of the tokenizer.
     */
    constructor(tokenizerJSON, tokenizerConfig) {
        super(tokenizerJSON, tokenizerConfig);

        // Matches a leading `>>code<<` plus any whitespace that follows it.
        // Intentionally NOT global: a `g` regex keeps `lastIndex` between `.test()`
        // calls, which made the filter below skip every second consecutive
        // language code in the vocabulary.
        this.languageRegex = /^(>>\w+<<)\s*/;

        // Every vocabulary entry that begins with a language code.
        this.supported_language_codes = this.model.vocab.filter(
            token => this.languageRegex.test(token)
        );
    }

    /**
     * Encodes a single text. Overriding this method is necessary since the language codes
     * must be removed before encoding with sentencepiece model.
     * @see https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/models/marian/tokenization_marian.py#L204-L213
     *
     * @param {string|null} text - The text to encode.
     * @returns {Array|null} The encoded tokens (prefixed with the language code when one
     * is present), or `null` if `text` is `null`.
     */
    _encode_text(text) {
        if (text === null) return null;

        // Splitting on an anchored regex with one capture group yields either
        //   [text]                     when there is no leading language code, or
        //   ['', languageCode, rest]   when there is one.
        const [head, ...remainder] = text.trim().split(this.languageRegex);

        if (remainder.length === 0) {
            // No language code, encode normally.
            return super._encode_text(head);
        }

        // Text starts with a language code, so the code itself is not encoded by the
        // parent implementation; it is emitted verbatim as the first token.
        const [language, content] = remainder;

        if (!this.supported_language_codes.includes(language)) {
            console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`);
        }
        return [language, ...super._encode_text(content)];
    }
}