diff --git a/nemo/collections/llm/gpt/data/squad.py b/nemo/collections/llm/gpt/data/squad.py
index 77d48da98a0e..0ce7c52a14bd 100644
--- a/nemo/collections/llm/gpt/data/squad.py
+++ b/nemo/collections/llm/gpt/data/squad.py
@@ -124,3 +124,6 @@ def _preprocess_and_split_data(
                     shutil.rmtree(p)
                 elif '.jsonl' not in str(p.name):
                     p.unlink()
+
+    def reconfigure_limit_batches(self):
+        return
diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py
index 6256b67515ee..96edeadd439a 100644
--- a/nemo/collections/llm/gpt/model/mixtral.py
+++ b/nemo/collections/llm/gpt/model/mixtral.py
@@ -47,6 +47,8 @@ class MixtralConfig8x7B(GPTConfig):
     # rotary
     rotary_percent: float = 0.5
     rotary_base: float = 10000
+    bf16: bool = True
+    params_dtype: torch.dtype = torch.bfloat16
 
 
 class MixtralModel(GPTModel):
@@ -70,7 +72,7 @@ def init(self) -> MixtralModel:
     def apply(self, output_path: Path) -> Path:
         from transformers import MixtralForCausalLM
 
-        source = MixtralForCausalLM.from_pretrained(str(self))
+        source = MixtralForCausalLM.from_pretrained(str(self), torch_dtype='auto', use_safetensors=True)
         target = self.init()
         trainer = self.nemo_setup(target)
         self.convert_state(source, target)
@@ -109,6 +111,7 @@ def config(self) -> MixtralConfig8x7B:
 
         config = HfMixtralConfig.from_pretrained(str(self))
         return MixtralConfig8x7B(
+            bf16=getattr(config, "torch_dtype", None) == torch.bfloat16,
             activation_func=F.silu,
             # network
             num_layers=config.num_hidden_layers,
@@ -132,6 +135,10 @@ def config(self) -> MixtralConfig8x7B:
             gated_linear_unit=True,
             # Vocab
             make_vocab_size_divisible_by=128,
+            # CPU init
+            use_cpu_initialization=True,
+            perform_initialization=False,
+            params_dtype=getattr(config, "torch_dtype", torch.bfloat16),
         )
 
 
diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py
index b69fed9d0f4f..9fd81a960358 100644
--- a/nemo/lightning/io/state.py
+++ b/nemo/lightning/io/state.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload
 
 import numpy as np
+import torch
 from torch import nn
 
 SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
@@ -19,11 +20,12 @@ class TransformCTX:
     target_state: dict
 
 
+@torch.no_grad
 def apply_transforms(
     source: nn.Module,
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
 ) -> TargetModuleT:
     """
     Applies a series of transformations to adapt the state dictionary of a source module to
@@ -101,9 +103,8 @@ def scale_weights(ctx):
     for key, val in mapping.items():
         ctx = StateDictTransform(key, val)(ctx)
 
-    if transforms:
-        for transform in transforms:
-            ctx = transform(ctx)
+    for transform in transforms:
+        ctx = transform(ctx)
 
     _params: Dict[str, nn.Parameter] = {}
     for name, param in _target.named_parameters():
@@ -144,9 +145,9 @@ def scale_weights(ctx):
 
                 _module.register_buffer(_key, val)
 
-    keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")]
+    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
     if len(keys) != 0:
-        raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.")
+        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
 
     # TODO: Is this correct?
     # for key in target.state_dict():
@@ -165,7 +166,7 @@ def scale_weights(ctx):
 
 
 def _default_transform(inp):
-    return inp.float()
+    return inp
 
 
 class StateDictTransform(Generic[F]):
@@ -324,7 +325,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
     regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
     wildcard_matches = [[] for _ in range(pattern.count("*"))]
 
-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
         match = regex_pattern.match(key)
         if match:
             for i, group in enumerate(match.groups()):
@@ -342,7 +343,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
     output_array = np.empty(shape, dtype=object)
 
     # Populate the array with the keys, now that we have the correct shape and ordering
-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
         match = regex_pattern.match(key)
         if match:
             # Convert match groups to indices based on their position in wildcard_matches
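
Note: the snippet below is illustrative only and is not part of the diff. It is a minimal, self-contained toy showing the two behavioural changes made to apply_transforms above (decorating with @torch.no_grad and defaulting transforms to an empty list), assuming a PyTorch version where torch.no_grad works as a bare decorator. The copy_weights function and the lambda transform are hypothetical stand-ins, not NeMo APIs.

import torch
from torch import nn


@torch.no_grad  # as in the diff: no autograd graph is recorded while weights are copied
def copy_weights(source: nn.Module, target: nn.Module, transforms=[]):
    state = {name: param for name, param in source.named_parameters()}
    for transform in transforms:  # with the [] default this loop simply runs zero times
        state = transform(state)
    for name, param in target.named_parameters():
        # the in-place copy into a leaf parameter is only legal because we are under no_grad
        param.copy_(state[name])
    return target


src, dst = nn.Linear(4, 4), nn.Linear(4, 4)
copy_weights(src, dst, transforms=[lambda s: {k: v * 1.0 for k, v in s.items()}])
assert torch.equal(src.weight, dst.weight)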