Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OOM error fixing #58

Merged
merged 9 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Any, cast

Expand All @@ -15,6 +16,8 @@
from ..nmt_model_factory import NmtModelFactory
from ..shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class HuggingFaceNmtModelFactory(NmtModelFactory):
def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
Expand Down Expand Up @@ -67,7 +70,11 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
)

def create_engine(self) -> TranslationEngine:
def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
if half_previous_batch_size:
self._config.huggingface.generate_params.batch_size = max(
self._config.huggingface.generate_params.batch_size // 2, 1
)
return HuggingFaceNmtEngine(
self._model,
src_lang=self._config.src_lang,
Expand Down
2 changes: 1 addition & 1 deletion machine/jobs/nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
...

@abstractmethod
def create_engine(self) -> TranslationEngine:
def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
...

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions machine/jobs/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ default:
max_steps: 20000
data_dir: ~/machine
batch_size: 1024
oom_batch_size_backoff_multiplier: 0.5
huggingface:
parent_model_name: facebook/nllb-200-distilled-1.3B
train_params:
Expand Down
75 changes: 63 additions & 12 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import gc
import logging
from math import exp, prod
from typing import Any, Iterable, List, Sequence, Tuple, Union, cast

Expand All @@ -17,29 +18,35 @@
from ..translation_sources import TranslationSources
from ..word_alignment_matrix import WordAlignmentMatrix

logger = logging.getLogger(__name__)


class HuggingFaceNmtEngine(TranslationEngine):
def __init__(
self,
model: Union[PreTrainedModel, StrPath, str],
**pipeline_kwargs,
) -> None:
if isinstance(model, PreTrainedModel):
model.eval()
self._model = model
self._pipeline_kwargs = pipeline_kwargs
if isinstance(self._model, PreTrainedModel):
self._model.eval()
else:
model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
self._model = cast(
PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
)
self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)

src_lang = pipeline_kwargs.get("src_lang")
tgt_lang = pipeline_kwargs.get("tgt_lang")
src_lang = self._pipeline_kwargs.get("src_lang")
tgt_lang = self._pipeline_kwargs.get("tgt_lang")
if (
src_lang is not None
and tgt_lang is not None
and "prefix" not in pipeline_kwargs
and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
and "prefix" not in self._pipeline_kwargs
and (self._model.name_or_path.startswith("t5-") or self._model.name_or_path.startswith("google/mt5-"))
):
pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
else:
additional_special_tokens = self._tokenizer.additional_special_tokens
if (
Expand All @@ -56,10 +63,16 @@ def __init__(
):
raise ValueError(f"'{tgt_lang}' is not a valid language code.")

self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1))

# If not set, default to not backing off (1.0).
self._oom_batch_size_backoff_multiplier = self._pipeline_kwargs.pop("oom_batch_size_backoff_multiplier", 1.0)

self._pipeline = _TranslationPipeline(
model=model,
model=self._model,
tokenizer=self._tokenizer,
**pipeline_kwargs,
batch_size=self._batch_size,
**self._pipeline_kwargs,
)

def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
Expand All @@ -73,6 +86,44 @@ def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequ

def translate_n_batch(
    self, n: int, segments: Sequence[Union[str, Sequence[str]]]
) -> Sequence[Sequence[TranslationResult]]:
    """Translate ``segments`` in chunks of ``self._batch_size``, backing off on failure.

    The segments are fed to ``self._try_translate_n_batch`` in slices of
    ``self._batch_size``. If any slice raises, the batch size is shrunk by
    ``self._oom_batch_size_backoff_multiplier``, the pipeline is rebuilt with
    the smaller batch size, and the whole input is translated again from the
    start.

    Args:
        n: Passed through unchanged to ``_try_translate_n_batch``.
        segments: The segments to translate. A single ``str`` is treated as
            one segment.

    Returns:
        The concatenated results of every slice, in input order.

    Raises:
        Exception: When a slice fails and the back-off multiplier is
            effectively 1 (back-off disabled); the original error is chained
            as ``__cause__``.
        Exception: The original error, re-raised unchanged, when a slice
            fails even though the batch size is already 1.
    """
    # Normalise once, before the retry loop. ``isinstance`` also covers str
    # subclasses, which would otherwise be split into single characters.
    if isinstance(segments, str):
        segments = [segments]
    else:
        segments = list(segments)
    outer_batch_size = len(segments)

    while True:
        all_results: List[Sequence[TranslationResult]] = []
        try:
            for step in range(0, outer_batch_size, self._batch_size):
                all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
            return all_results
        except Exception as e:
            # The out-of-memory error has no dedicated exception class that can
            # be caught here yet, so every exception is treated as a likely OOM.
            if self._oom_batch_size_backoff_multiplier >= 0.9999:
                # FIXME after upgrading to Pytorch 2.1, this should be changed to OutOfMemoryError
                # see https://github.com/sillsdev/machine.py/issues/67
                raise Exception(
                    "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier "
                    + "to < 1 to gracefuly handle these type of errors."
                ) from e
            if self._batch_size == 1:
                # Nothing left to shrink, so it is probably not a memory problem.
                # A bare ``raise`` keeps the original traceback intact.
                raise
            scaled = int(round(self._batch_size * self._oom_batch_size_backoff_multiplier))
            # Rounding can give back the same size (e.g. 2 * 0.8 -> 2), which
            # would retry forever; always shrink by at least one.
            self._batch_size = max(min(scaled, self._batch_size - 1), 1)
            # Log ``e`` itself rather than ``e.args[0]``: an exception raised
            # without arguments would otherwise fail with IndexError here.
            logger.info(
                "Out of memory error caught with message %s, reducing batch size to %s. "
                "Remaking translation pipeline.",
                e,
                self._batch_size,
            )
            self._pipeline = _TranslationPipeline(
                model=self._model,
                tokenizer=self._tokenizer,
                batch_size=self._batch_size,
                **self._pipeline_kwargs,
            )

def _try_translate_n_batch(
self, n: int, segments: Sequence[Union[str, Sequence[str]]]
) -> Sequence[Sequence[TranslationResult]]:
all_results: List[List[TranslationResult]] = []
i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def train(
num_labels=0,
)
model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config))

logger.info("Initializing tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)

src_lang = self._src_lang
Expand Down Expand Up @@ -194,6 +196,7 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)

if self._add_unk_src_tokens or self._add_unk_trg_tokens:
logger.info("Checking for missing tokens")
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
f"Tokenizer can not be updated from default configuration: \
Expand Down Expand Up @@ -234,6 +237,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
tokenizer.id_to_lang_token[lang_id] = lang_code

if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
logger.info("Add new language codes as tokens")
if self._src_lang is not None:
add_lang_code_to_tokenizer(tokenizer, self._src_lang)
if self._tgt_lang is not None:
Expand Down Expand Up @@ -309,6 +313,7 @@ def preprocess_function(examples):
model_inputs["labels"] = labels["input_ids"]
return model_inputs

logger.info("Run tokenizer")
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
Expand Down Expand Up @@ -339,17 +344,21 @@ def preprocess_function(examples):
],
)

logger.info("Train NMT model")
ckpt = None
if self._training_args.resume_from_checkpoint is not None:
ckpt = self._training_args.resume_from_checkpoint
elif last_checkpoint is not None:
ckpt = last_checkpoint
train_result = self._trainer.train(resume_from_checkpoint=ckpt)
train_result = self._trainer.train(
resume_from_checkpoint=ckpt,
)

self._metrics = train_result.metrics
self._metrics["train_samples"] = len(train_dataset)

self._trainer.log_metrics("train", self._metrics)
logger.info("Model training finished")

def save(self) -> None:
if self._trainer is None:
Expand Down
4 changes: 2 additions & 2 deletions machine/translation/translation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from abc import abstractmethod
from types import TracebackType
from typing import ContextManager, Optional, Sequence, Type, Union
from typing import Optional, Sequence, Type, Union

from .translation_result import TranslationResult


class TranslationEngine(ContextManager["TranslationEngine"]):
class TranslationEngine:
@abstractmethod
def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
...
Expand Down