Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OOM error fixing #58

Merged
merged 9 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Any, cast

Expand All @@ -9,12 +10,14 @@
from ...corpora.text_corpus import TextCorpus
from ...translation.huggingface.hugging_face_nmt_engine import HuggingFaceNmtEngine
from ...translation.huggingface.hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer
from ...translation.nmt_translation_engine import NmtTranslationEngine
from ...translation.null_trainer import NullTrainer
from ...translation.trainer import Trainer
from ...translation.translation_engine import TranslationEngine
from ..nmt_model_factory import NmtModelFactory
from ..shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class HuggingFaceNmtModelFactory(NmtModelFactory):
def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
Expand Down Expand Up @@ -67,7 +70,11 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
)

def create_engine(self) -> TranslationEngine:
def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine:
if half_previous_batch_size:
self._config.huggingface.generate_params.batch_size = max(
self._config.huggingface.generate_params.batch_size // 2, 1
)
return HuggingFaceNmtEngine(
self._model,
src_lang=self._config.src_lang,
Expand Down
46 changes: 34 additions & 12 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional, Sequence

from ..corpora.corpora_utils import batch
from ..translation.translation_engine import TranslationEngine
from ..translation.nmt_translation_engine import NmtTranslationEngine
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
from .nmt_model_factory import NmtModelFactory
Expand Down Expand Up @@ -81,26 +81,48 @@ def run(
inference_step_count = sum(1 for _ in src_pretranslations)
with ExitStack() as stack:
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
engine = stack.enter_context(self._nmt_model_factory.create_engine())
src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["batch_size"]
translate_batch = TranslateBatch(stack, self._nmt_model_factory)
for pi_batch in batch(src_pretranslations, batch_size):
if check_canceled is not None:
check_canceled()
_translate_batch(engine, pi_batch, writer)
translate_batch.translate(pi_batch, writer)
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))


def _translate_batch(
engine: TranslationEngine,
batch: Sequence[PretranslationInfo],
writer: PretranslationWriter,
) -> None:
source_segments = [pi["translation"] for pi in batch]
for i, result in enumerate(engine.translate_batch(source_segments)):
batch[i]["translation"] = result.translation
writer.write(batch[i])
batch_divisor = 1


class TranslateBatch:
    """Translate pretranslation batches, asking for a smaller engine whenever an attempt fails.

    Any exception raised while translating is treated as a possible out-of-memory
    condition: the current engine is released, the factory is asked for a new one
    with ``half_previous_batch_size=True``, and the whole batch is tried again.
    The error is re-raised once the batch size can no longer be reduced, so a
    failure that is not memory-related cannot loop forever.
    """

    def __init__(self, stack: ExitStack, nmt_model_factory: NmtModelFactory):
        """Create the initial engine.

        Args:
            stack: Exit stack owned by the caller; the engine is released when it unwinds.
            nmt_model_factory: Factory used to create the engine and its replacements.
        """
        self._stack = stack
        self._nmt_model_factory = nmt_model_factory
        # The engine lives in a child stack so that a failed engine can be exited
        # before its replacement is built, instead of piling up on the caller's stack.
        self._engine_stack: ExitStack = self._stack.enter_context(ExitStack())
        self._engine: NmtTranslationEngine = self._engine_stack.enter_context(
            self._nmt_model_factory.create_engine()
        )

    def translate(
        self,
        batch: Sequence[PretranslationInfo],
        writer: PretranslationWriter,
    ) -> None:
        """Translate every item of ``batch`` in place and write each one to ``writer``.

        Args:
            batch: Items whose ``"translation"`` entry holds the source text on entry
                and the translated text on return.
            writer: Receives every item of ``batch``, in order, after the whole batch
                has been translated.

        Raises:
            Exception: The last translation error, when the engine batch size is
                already 1 or the replacement engine did not get a smaller batch size.
        """
        # Read the sources once, before anything is overwritten, so that a retry
        # never feeds already-translated text back into the engine.
        source_segments = [pi["translation"] for pi in batch]
        while True:
            engine_batch_size = self._engine.get_batch_size()
            pending = {}
            try:
                for step in range(0, len(source_segments), engine_batch_size):
                    chunk = source_segments[step : step + engine_batch_size]
                    for i, result in enumerate(self._engine.translate_batch(chunk)):
                        pending[step + i] = result.translation
            except Exception:
                if engine_batch_size <= 1:
                    # Nothing left to shrink: surface the real error.
                    raise
                logger.info(f"Out of memory error, reducing batch size to {engine_batch_size // 2}")
                # Exit the failed engine first so its resources are freed before the new one is built.
                self._engine_stack.close()
                self._engine = self._engine_stack.enter_context(
                    self._nmt_model_factory.create_engine(half_previous_batch_size=True)
                )
                if self._engine.get_batch_size() >= engine_batch_size:
                    # The factory did not reduce the batch size; retrying would repeat the same failure.
                    raise
            else:
                break

        # Results are applied and written only after the whole batch succeeded, and
        # outside the retry loop, so a writer failure is never mistaken for an OOM.
        for index, translation in pending.items():
            batch[index]["translation"] = translation
        for pi in batch:
            writer.write(pi)
4 changes: 2 additions & 2 deletions machine/jobs/nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from ..corpora.parallel_text_corpus import ParallelTextCorpus
from ..corpora.text_corpus import TextCorpus
from ..translation.nmt_translation_engine import NmtTranslationEngine
from ..translation.trainer import Trainer
from ..translation.translation_engine import TranslationEngine


class NmtModelFactory(ABC):
Expand All @@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
...

@abstractmethod
def create_engine(self) -> TranslationEngine:
def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine:
...

@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions machine/jobs/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ default:
max_steps: 20000
data_dir: ~/machine
batch_size: 1024
oom_batch_size_backoff_multiplier: 0.5
huggingface:
parent_model_name: facebook/nllb-200-distilled-1.3B
train_params:
Expand Down Expand Up @@ -33,5 +34,7 @@ staging:
max_steps: 10
huggingface:
parent_model_name: facebook/nllb-200-distilled-600M
train_params:
group_by_length: false
generate_params:
num_beams: 1
78 changes: 64 additions & 14 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import gc
import logging
from math import exp, prod
from typing import Any, Iterable, List, Sequence, Tuple, Union, cast

Expand All @@ -11,35 +12,41 @@

from ...annotations.range import Range
from ...utils.typeshed import StrPath
from ..translation_engine import TranslationEngine
from ..nmt_translation_engine import NmtTranslationEngine
from ..translation_result import TranslationResult
from ..translation_result_builder import TranslationResultBuilder
from ..translation_sources import TranslationSources
from ..word_alignment_matrix import WordAlignmentMatrix

logger = logging.getLogger(__name__)

class HuggingFaceNmtEngine(TranslationEngine):

class HuggingFaceNmtEngine(NmtTranslationEngine):
def __init__(
self,
model: Union[PreTrainedModel, StrPath, str],
**pipeline_kwargs,
) -> None:
if isinstance(model, PreTrainedModel):
model.eval()
self._model = model
self._pipeline_kwargs = pipeline_kwargs
if isinstance(self._model, PreTrainedModel):
self._model.eval()
else:
model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
self._model = cast(
PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
)
self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)

src_lang = pipeline_kwargs.get("src_lang")
tgt_lang = pipeline_kwargs.get("tgt_lang")
src_lang = self._pipeline_kwargs.get("src_lang")
tgt_lang = self._pipeline_kwargs.get("tgt_lang")
if (
src_lang is not None
and tgt_lang is not None
and "prefix" not in pipeline_kwargs
and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
and "prefix" not in self._pipeline_kwargs
and (self._model.name_or_path.startswith("t5-") or self._model.name_or_path.startswith("google/mt5-"))
):
pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
else:
additional_special_tokens = self._tokenizer.additional_special_tokens
if (
Expand All @@ -56,10 +63,20 @@ def __init__(
):
raise ValueError(f"'{tgt_lang}' is not a valid language code.")

batch_size = self._pipeline_kwargs.pop("batch_size")
if batch_size is not None:
self._batch_size = int(batch_size) # type: ignore[assignment]
else:
self._batch_size = 16

# If not set, default to not backing off (1.0).
self._oom_batch_size_backoff_multiplier = self._pipeline_kwargs.pop("oom_batch_size_backoff_multiplier", 1.0)

self._pipeline = _TranslationPipeline(
model=model,
model=self._model,
tokenizer=self._tokenizer,
**pipeline_kwargs,
batch_size=self._batch_size,
**self._pipeline_kwargs,
)

def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
Expand All @@ -71,8 +88,41 @@ def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[Tr
def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]:
return [results[0] for results in self.translate_n_batch(1, segments)]

def get_batch_size(self) -> int:
return self._batch_size

def translate_n_batch(
    self, n: int, segments: Sequence[Union[str, Sequence[str]]]
) -> Sequence[Sequence[TranslationResult]]:
    """Translate ``segments`` in chunks of the current batch size, shrinking it when an attempt fails.

    Any exception raised while translating is treated as a likely out-of-memory
    error. The batch size is multiplied by ``oom_batch_size_backoff_multiplier``
    (never going below 1), the translation pipeline is rebuilt with the new size,
    and all segments are translated again from the start.

    Args:
        n: Number of translations requested per segment; passed through to
            ``_try_translate_n_batch``.
        segments: The segments to translate. A single string is treated as one segment.

    Returns:
        The results for each segment, in input order.

    Raises:
        Exception: Wrapping the original error when the backoff multiplier is
            effectively 1 (backoff disabled). The original error is re-raised
            unchanged when the batch size cannot be reduced any further.
    """
    # Normalize once, outside the retry loop. A bare string is a single segment,
    # not a sequence of one-character segments.
    segments = [segments] if isinstance(segments, str) else list(segments)
    while True:
        all_results: List[Sequence[TranslationResult]] = []
        try:
            for step in range(0, len(segments), self._batch_size):
                all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
            return all_results
        except Exception as e:
            if self._oom_batch_size_backoff_multiplier >= 0.9999:
                raise Exception(
                    "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier to < 1 to gracefuly handle these type of errors."
                ) from e
            reduced_batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_multiplier)), 1)
            if reduced_batch_size >= self._batch_size:
                # Already at the floor (e.g. round(1 * 0.5) == 0 is clamped back to 1):
                # retrying with the same size would loop forever, so surface the error.
                raise
            self._batch_size = reduced_batch_size
            logger.info(
                f"Out of memory error caught, reducing batch size to {self._batch_size}. Remaking translation pipeline."
            )
            self._pipeline = _TranslationPipeline(
                model=self._model,
                tokenizer=self._tokenizer,
                batch_size=self._batch_size,
                **self._pipeline_kwargs,
            )

def _try_translate_n_batch(
self, n: int, segments: Sequence[Union[str, Sequence[str]]]
) -> Sequence[Sequence[TranslationResult]]:
all_results: List[List[TranslationResult]] = []
i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def train(
# Set seed before initializing model.
set_seed(self._training_args.seed)

logger.info("Initializing tokenizer.")
if isinstance(self._model, PreTrainedModel):
model = self._model
self._original_use_cache = model.config.use_cache
Expand Down Expand Up @@ -193,6 +194,7 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
logger.info(f"Added {len(missing_tokens)} tokens to the tokenizer: {missing_tokens}")
return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)

logger.info("Checking for missing tokens.")
if self._add_unk_src_tokens or self._add_unk_trg_tokens:
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
Expand Down Expand Up @@ -233,6 +235,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
tokenizer.lang_token_to_id[lang_code] = lang_id
tokenizer.id_to_lang_token[lang_id] = lang_code

logger.info("Add new language codes as tokens.")
if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
if self._src_lang is not None:
add_lang_code_to_tokenizer(tokenizer, self._src_lang)
Expand Down Expand Up @@ -309,6 +312,7 @@ def preprocess_function(examples):
model_inputs["labels"] = labels["input_ids"]
return model_inputs

logger.info("Run tokenizer.")
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
Expand Down Expand Up @@ -339,17 +343,21 @@ def preprocess_function(examples):
],
)

logger.info("Train NMT model.")
ckpt = None
if self._training_args.resume_from_checkpoint is not None:
ckpt = self._training_args.resume_from_checkpoint
elif last_checkpoint is not None:
ckpt = last_checkpoint
train_result = self._trainer.train(resume_from_checkpoint=ckpt)
train_result = self._trainer.train(
resume_from_checkpoint=ckpt,
)

self._metrics = train_result.metrics
self._metrics["train_samples"] = len(train_dataset)

self._trainer.log_metrics("train", self._metrics)
logger.info("Model training finished.")

def save(self) -> None:
if self._trainer is None:
Expand Down
12 changes: 12 additions & 0 deletions machine/translation/nmt_translation_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from abc import abstractmethod
from typing import ContextManager

from .translation_engine import TranslationEngine


class NmtTranslationEngine(TranslationEngine, ContextManager["NmtTranslationEngine"]):
    """A translation engine that is also a context manager and reports its batch size."""

    @abstractmethod
    def get_batch_size(self) -> int:
        """Return the batch size the engine is currently using."""
4 changes: 2 additions & 2 deletions machine/translation/translation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from abc import abstractmethod
from types import TracebackType
from typing import ContextManager, Optional, Sequence, Type, Union
from typing import Optional, Sequence, Type, Union

from .translation_result import TranslationResult


class TranslationEngine(ContextManager["TranslationEngine"]):
class TranslationEngine:
@abstractmethod
def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
...
Expand Down
5 changes: 3 additions & 2 deletions tests/jobs/test_nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from machine.corpora import DictionaryTextCorpus
from machine.jobs import NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, PretranslationWriter, SharedFileService
from machine.translation import Phrase, Trainer, TrainStats, TranslationResult, TranslationSources, WordAlignmentMatrix
from machine.translation.translation_engine import TranslationEngine
from machine.translation.nmt_translation_engine import NmtTranslationEngine
from machine.utils import CanceledError, ContextManagedGenerator


Expand Down Expand Up @@ -45,8 +45,9 @@ def __init__(self, decoy: Decoy) -> None:
stats.metrics["bleu"] = 30.0
decoy.when(self.model_trainer.stats).then_return(stats)

self.engine = decoy.mock(cls=TranslationEngine)
self.engine = decoy.mock(cls=NmtTranslationEngine)
decoy.when(self.engine.__enter__()).then_return(self.engine)
decoy.when(self.engine.get_batch_size()).then_return(16)
decoy.when(self.engine.translate_batch(matchers.Anything())).then_return(
[
TranslationResult(
Expand Down
Loading