From 3195281682b6f2b7d64254aae67edd396966e816 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 18 Apr 2024 19:28:41 -0700 Subject: [PATCH 1/2] add it --- llmfoundry/callbacks/hf_checkpointer.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index baa72a7f66..3716a70f05 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -12,7 +12,7 @@ import time from multiprocessing.context import SpawnProcess from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -273,6 +273,24 @@ def _all_child_processes_done(self) -> bool: dist.all_reduce(x, reduce_operation='MAX') return x.item() == 0 + def transform_model_and_tokenizer( + self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase + ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: + """Transform the model and tokenizer before saving. + + This allows a subclass to modify the model and tokenizer before saving. The base class implementation will + make no modifications. + + Args: + model (PreTrainedModel): The model to be transformed. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be transformed. + + Returns: + Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The transformed model and tokenizer. + """ + + return model, tokenizer + def _save_checkpoint(self, state: State, logger: Logger): del logger # unused @@ -405,6 +423,10 @@ def dtensor_to_tensor_hook( new_model_instance.load_state_dict(state_dict, assign=True) del state_dict + # Transform the model and tokenizer before saving + new_model_instance, original_tokenizer = self.transform_model_and_tokenizer( + new_model_instance, original_tokenizer) + log.debug('Saving Hugging Face checkpoint to disk') new_model_instance.save_pretrained(temp_save_dir) if original_tokenizer is not None: From 4ec555a7999ed19633827e2df0e886c744a147d2 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Thu, 18 Apr 2024 19:47:21 -0700 Subject: [PATCH 2/2] pc --- llmfoundry/callbacks/hf_checkpointer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3716a70f05..f899206add 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -288,7 +288,6 @@ def transform_model_and_tokenizer( Returns: Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The transformed model and tokenizer. """ - return model, tokenizer def _save_checkpoint(self, state: State, logger: Logger):