diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
index 8601f71..e97bfb9 100644
--- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
+++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import Any, cast
 
@@ -15,6 +16,8 @@
 from ..nmt_model_factory import NmtModelFactory
 from ..shared_file_service import SharedFileService
 
+logger = logging.getLogger(__name__)
+
 
 class HuggingFaceNmtModelFactory(NmtModelFactory):
     def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
@@ -67,7 +70,11 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
             add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
         )
 
-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
+        if half_previous_batch_size:
+            self._config.huggingface.generate_params.batch_size = max(
+                self._config.huggingface.generate_params.batch_size // 2, 1
+            )
         return HuggingFaceNmtEngine(
             self._model,
             src_lang=self._config.src_lang,
@@ -76,6 +83,7 @@
             num_beams=self._config.huggingface.generate_params.num_beams,
             batch_size=self._config.huggingface.generate_params.batch_size,
             truncation=TruncationStrategy.LONGEST_FIRST,
+            oom_batch_size_backoff_mult=self._config.huggingface.generate_params.oom_batch_size_backoff_mult,
         )
 
     def save_model(self) -> None:
diff --git a/machine/jobs/nmt_model_factory.py b/machine/jobs/nmt_model_factory.py
index 850280b..6161320 100644
--- a/machine/jobs/nmt_model_factory.py
+++ b/machine/jobs/nmt_model_factory.py
@@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
         ...
 
     @abstractmethod
-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
         ...
 
     @abstractmethod
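The `half_previous_batch_size` flag mutates the configured batch size in place, so each call halves it again, floored at 1. As a sketch of how a job might drive the flag after a hard CUDA out-of-memory failure (the `build_engine_with_backoff` helper and the warm-up segment are illustrative assumptions, not part of this change):

import torch

def build_engine_with_backoff(factory, attempts: int = 4):
    # Sketch: on CUDA OOM, ask the factory for a fresh engine configured
    # with half the previous batch size (the factory floors it at 1).
    engine = factory.create_engine()
    for _ in range(attempts):
        try:
            engine.translate("warm-up segment")
            return engine
        except torch.cuda.OutOfMemoryError:
            engine = factory.create_engine(half_previous_batch_size=True)
    return engine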
diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml
index e1a692b..6c00382 100644
--- a/machine/jobs/settings.yaml
+++ b/machine/jobs/settings.yaml
@@ -20,6 +20,7 @@ default:
       device: 0
       num_beams: 2
       batch_size: 16
+      oom_batch_size_backoff_mult: 0.5
     tokenizer:
       add_unk_src_tokens: true
       add_unk_trg_tokens: true
@@ -34,4 +35,4 @@ staging:
   huggingface:
     parent_model_name: facebook/nllb-200-distilled-600M
     generate_params:
-      num_beams: 1
+      num_beams: 1
\ No newline at end of file
diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py
index e730d16..091f2a0 100644
--- a/machine/translation/huggingface/hugging_face_nmt_engine.py
+++ b/machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import logging
 from math import exp, prod
 from typing import Any, Iterable, List, Sequence, Tuple, Union, cast
 
@@ -17,29 +18,36 @@
 from ..translation_sources import TranslationSources
 from ..word_alignment_matrix import WordAlignmentMatrix
 
+logger = logging.getLogger(__name__)
+
 
 class HuggingFaceNmtEngine(TranslationEngine):
     def __init__(
         self,
         model: Union[PreTrainedModel, StrPath, str],
+        oom_batch_size_backoff_mult: float = 1.0,
         **pipeline_kwargs,
     ) -> None:
-        if isinstance(model, PreTrainedModel):
-            model.eval()
+        self._model = model
+        self._pipeline_kwargs = pipeline_kwargs
+        if isinstance(self._model, PreTrainedModel):
+            self._model.eval()
         else:
-            model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
-            model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
-        self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
+            model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
+            self._model = cast(
+                PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
+            )
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
 
-        src_lang = pipeline_kwargs.get("src_lang")
-        tgt_lang = pipeline_kwargs.get("tgt_lang")
+        src_lang = self._pipeline_kwargs.get("src_lang")
+        tgt_lang = self._pipeline_kwargs.get("tgt_lang")
         if (
             src_lang is not None
             and tgt_lang is not None
-            and "prefix" not in pipeline_kwargs
-            and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
+            and "prefix" not in self._pipeline_kwargs
+            and (self._model.name_or_path.startswith("t5-") or self._model.name_or_path.startswith("google/mt5-"))
         ):
-            pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
+            self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
         else:
             additional_special_tokens = self._tokenizer.additional_special_tokens
             if (
@@ -56,10 +64,15 @@
             ):
                 raise ValueError(f"'{tgt_lang}' is not a valid language code.")
 
+        self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1))
+
+        self._oom_batch_size_backoff_mult = oom_batch_size_backoff_mult
+
         self._pipeline = _TranslationPipeline(
-            model=model,
+            model=self._model,
             tokenizer=self._tokenizer,
-            **pipeline_kwargs,
+            batch_size=self._batch_size,
+            **self._pipeline_kwargs,
         )
 
     def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
@@ -73,6 +86,32 @@
 
     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
+    ) -> Sequence[Sequence[TranslationResult]]:
+        while True:
+            if type(segments) is str:
+                segments = [segments]
+            else:
+                segments = [segment for segment in segments]
+            outer_batch_size = len(segments)
+            all_results: List[Sequence[TranslationResult]] = []
+            try:
+                for step in range(0, outer_batch_size, self._batch_size):
+                    all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
+                return all_results
+            except torch.cuda.OutOfMemoryError:  # type: ignore[reportGeneralTypeIssues]
+                if self._oom_batch_size_backoff_mult >= 0.9999 or self._batch_size <= 1:
+                    raise
+                self._batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_mult)), 1)
+                logger.warning(f"Out of memory error caught. Reducing batch size to {self._batch_size} and retrying.")
+                self._pipeline = _TranslationPipeline(
+                    model=self._model,
+                    tokenizer=self._tokenizer,
+                    batch_size=self._batch_size,
+                    **self._pipeline_kwargs,
+                )
+
+    def _try_translate_n_batch(
+        self, n: int, segments: Sequence[Union[str, Sequence[str]]]
     ) -> Sequence[Sequence[TranslationResult]]:
         all_results: List[List[TranslationResult]] = []
         i = 0
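The retry loop above shrinks the pipeline batch size on every CUDA out-of-memory error, rebuilds the pipeline, and re-raises once the multiplier is effectively 1.0 (the `>= 0.9999` guard, meaning backoff is disabled at the engine's default of 1.0) or the batch size is already 1. With the `settings.yaml` defaults shown earlier (`batch_size: 16`, `oom_batch_size_backoff_mult: 0.5`), the schedule of attempted batch sizes works out as in this small illustration:

# Batch sizes the engine will try before giving up, per the backoff rule.
batch_size, mult = 16, 0.5
schedule = [batch_size]
while batch_size > 1:
    batch_size = max(int(round(batch_size * mult)), 1)
    schedule.append(batch_size)
print(schedule)  # [16, 8, 4, 2, 1] -- a further OOM at batch size 1 is re-raised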
diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
index ddc8dfc..e4be94a 100644
--- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
+++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -147,6 +147,8 @@ def train(
             num_labels=0,
         )
         model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config))
+
+        logger.info("Initializing tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
 
         src_lang = self._src_lang
@@ -194,6 +196,7 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)
 
         if self._add_unk_src_tokens or self._add_unk_trg_tokens:
+            logger.info("Checking for missing tokens")
            if not isinstance(tokenizer, PreTrainedTokenizerFast):
                 logger.warning(
                     f"Tokenizer can not be updated from default configuration: \
@@ -234,6 +237,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
             tokenizer.id_to_lang_token[lang_id] = lang_code
 
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
+            logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
                 add_lang_code_to_tokenizer(tokenizer, self._src_lang)
             if self._tgt_lang is not None:
@@ -309,6 +313,7 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs
 
+        logger.info("Run tokenizer")
         train_dataset = train_dataset.map(
             preprocess_function,
             batched=True,
@@ -339,17 +344,21 @@ def preprocess_function(examples):
             ],
         )
 
+        logger.info("Train NMT model")
         ckpt = None
         if self._training_args.resume_from_checkpoint is not None:
             ckpt = self._training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             ckpt = last_checkpoint
-        train_result = self._trainer.train(resume_from_checkpoint=ckpt)
+        train_result = self._trainer.train(
+            resume_from_checkpoint=ckpt,
+        )
         self._metrics = train_result.metrics
         self._metrics["train_samples"] = len(train_dataset)
 
         self._trainer.log_metrics("train", self._metrics)
+        logger.info("Model training finished")
 
     def save(self) -> None:
         if self._trainer is None:
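The trainer changes are pure observability: INFO-level progress messages around tokenizer initialization, tokenization, and training. They only show up if the host application enables INFO logging for the `machine` module loggers, e.g. with a minimal setup like the following (assumed usage, not part of this diff):

import logging

# Surface the new progress messages ("Initializing tokenizer",
# "Train NMT model", "Model training finished", ...) on stderr.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    level=logging.INFO,
)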
diff --git a/machine/translation/translation_engine.py b/machine/translation/translation_engine.py
index 56bf9ce..a152e2a 100644
--- a/machine/translation/translation_engine.py
+++ b/machine/translation/translation_engine.py
@@ -2,12 +2,12 @@
 from abc import abstractmethod
 from types import TracebackType
-from typing import ContextManager, Optional, Sequence, Type, Union
+from typing import Optional, Sequence, Type, Union
 
 from .translation_result import TranslationResult
 
 
-class TranslationEngine(ContextManager["TranslationEngine"]):
+class TranslationEngine:
     @abstractmethod
     def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
         ...
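Dropping the `ContextManager` base changes only the static type: the `with` statement is duck-typed and simply looks up `__enter__`/`__exit__` on the instance, and the retained `TracebackType` import suggests those methods are still defined on the class. A minimal sketch of a subclass that stays usable in a `with` block (`MyEngine` and its `close` method are hypothetical, and the abstract translate methods are omitted for brevity):

from types import TracebackType
from typing import Optional, Type

class MyEngine(TranslationEngine):  # hypothetical subclass
    def __enter__(self) -> "MyEngine":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.close()  # assumes the subclass defines some cleanup method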