Skip to content

Commit

Permalink
Set HuggingFaceNmtEngine to not truncate by default
Browse files (browse the repository at this point in the history)
  • Loading branch information
ddaspit committed Oct 20, 2023
1 parent c565239 commit 4023940
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from transformers import AutoConfig, AutoModelForSeq2SeqLM, HfArgumentParser, PreTrainedModel, Seq2SeqTrainingArguments
from transformers.integrations import ClearMLCallback
+ from transformers.tokenization_utils import TruncationStrategy

from ...corpora.parallel_text_corpus import ParallelTextCorpus
from ...corpora.text_corpus import TextCorpus
Expand Down Expand Up @@ -72,6 +73,7 @@ def create_engine(self) -> TranslationEngine:
device=self._config.huggingface.generate_params.device,
num_beams=self._config.huggingface.generate_params.num_beams,
batch_size=self._config.huggingface.generate_params.batch_size,
+ truncation=TruncationStrategy.LONGEST_FIRST,
)

def save_model(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def close(self) -> None:


class _TranslationPipeline(TranslationPipeline):
- def preprocess(self, *args, truncation=TruncationStrategy.LONGEST_FIRST, src_lang=None, tgt_lang=None):
+ def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
if self.tokenizer is None:
raise RuntimeError("No tokenizer is specified.")
sentences = [
Expand Down

0 comments on commit 4023940

Please sign in to comment.