diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index e0e9cb6..ce44c75 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -3,6 +3,7 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, HfArgumentParser, PreTrainedModel, Seq2SeqTrainingArguments from transformers.integrations import ClearMLCallback +from transformers.tokenization_utils import TruncationStrategy from ...corpora.parallel_text_corpus import ParallelTextCorpus from ...corpora.text_corpus import TextCorpus @@ -72,6 +73,7 @@ def create_engine(self) -> TranslationEngine: device=self._config.huggingface.generate_params.device, num_beams=self._config.huggingface.generate_params.num_beams, batch_size=self._config.huggingface.generate_params.batch_size, + truncation=TruncationStrategy.LONGEST_FIRST, ) def save_model(self) -> None: diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 7ae01fa..2f96182 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -93,7 +93,7 @@ def close(self) -> None: class _TranslationPipeline(TranslationPipeline): - def preprocess(self, *args, truncation=TruncationStrategy.LONGEST_FIRST, src_lang=None, tgt_lang=None): + def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None): if self.tokenizer is None: raise RuntimeError("No tokenizer is specified.") sentences = [