Akoumparouli/mixtral fixes for r2.0.0rc1 #9911

Merged: 4 commits, Jul 26, 2024
3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/data/squad.py
@@ -124,3 +124,6 @@ def _preprocess_and_split_data(
                shutil.rmtree(p)
            elif '.jsonl' not in str(p.name):
                p.unlink()
+
+    def reconfigure_limit_batches(self):
Collaborator comment:

Can we remove this now that a fix has been merged?

+        return
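
Added context, not part of the diff: the base data module's reconfigure_limit_batches is assumed to rescale the trainer's limit_train_batches and limit_val_batches for Megatron-style global batches, and this override opts SquadDataModule out of that step; as the review comment above suggests, it reads as a temporary workaround. A minimal sketch of the resulting pattern, with the import path and base class treated as assumptions:

    from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule  # assumed base class

    class SquadDataModule(FineTuningDataModule):
        ...  # existing dataset preparation code

        def reconfigure_limit_batches(self):
            # Intentionally a no-op: leave the trainer's limit_*_batches untouched
            # instead of letting the base class recompute them.
            return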
9 changes: 8 additions & 1 deletion nemo/collections/llm/gpt/model/mixtral.py
@@ -47,6 +47,8 @@ class MixtralConfig8x7B(GPTConfig):
    # rotary
    rotary_percent: float = 0.5
    rotary_base: float = 10000
+    bf16: bool = True
+    params_dtype: torch.dtype = torch.bfloat16


class MixtralModel(GPTModel):
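
Added note on the two fields introduced in the hunk above: together they make bf16 mixed precision and bfloat16 parameter allocation the default for this config, matching the dtype of the published Mixtral checkpoints. A minimal sketch, assuming MixtralConfig8x7B is constructible from its defaults:

    import torch
    from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B

    cfg = MixtralConfig8x7B()
    # Defaults introduced by this diff:
    assert cfg.bf16 is True
    assert cfg.params_dtype is torch.bfloat16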
@@ -70,7 +72,7 @@ def init(self) -> MixtralModel:
    def apply(self, output_path: Path) -> Path:
        from transformers import MixtralForCausalLM

-        source = MixtralForCausalLM.from_pretrained(str(self))
+        source = MixtralForCausalLM.from_pretrained(str(self), torch_dtype='auto', use_safetensors=True)
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
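
For reference, the effect of the two new from_pretrained arguments in Hugging Face Transformers, shown as a standalone sketch; the model id below is illustrative, while the importer itself passes a local path via str(self):

    from transformers import MixtralForCausalLM

    # torch_dtype='auto' keeps the dtype stored in the checkpoint (bfloat16 for the
    # published Mixtral weights) instead of upcasting everything to float32;
    # use_safetensors=True loads the .safetensors shards rather than pickle .bin files.
    model = MixtralForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-v0.1",  # illustrative model id
        torch_dtype="auto",
        use_safetensors=True,
    )
    print(next(model.parameters()).dtype)  # expected: torch.bfloat16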
@@ -109,6 +111,7 @@ def config(self) -> MixtralConfig8x7B:

        config = HfMixtralConfig.from_pretrained(str(self))
        return MixtralConfig8x7B(
+            bf16=getattr(config, "torch_dtype", None) == torch.bfloat16,
            activation_func=F.silu,
            # network
            num_layers=config.num_hidden_layers,
@@ -132,6 +135,10 @@ def config(self) -> MixtralConfig8x7B:
            gated_linear_unit=True,
            # Vocab
            make_vocab_size_divisible_by=128,
+            # CPU init
+            use_cpu_initialization=True,
+            perform_initialization=False,
+            params_dtype=getattr(config, "torch_dtype", torch.bfloat16),
        )
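
Added commentary on this hunk: use_cpu_initialization=True together with perform_initialization=False skips random weight initialization on the GPU, which is safe here because every parameter is immediately overwritten from the converted HF checkpoint, and the getattr fallback keeps bfloat16 when the HF config carries no torch_dtype. The dtype-detection pattern in isolation, with an illustrative model id:

    import torch
    from transformers import AutoConfig

    hf_config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # illustrative id
    # Mirrors the two expressions used in the diff above.
    bf16 = getattr(hf_config, "torch_dtype", None) == torch.bfloat16
    params_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
    print(bf16, params_dtype)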


19 changes: 10 additions & 9 deletions nemo/lightning/io/state.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload

import numpy as np
+import torch
from torch import nn

SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
@@ -19,11 +20,12 @@ class TransformCTX:
    target_state: dict


+@torch.no_grad
def apply_transforms(
    source: nn.Module,
    target: TargetModuleT,
    mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
) -> TargetModuleT:
    """
    Applies a series of transformations to adapt the state dictionary of a source module to
@@ -101,9 +103,8 @@ def scale_weights(ctx):
    for key, val in mapping.items():
        ctx = StateDictTransform(key, val)(ctx)

-    if transforms:
-        for transform in transforms:
-            ctx = transform(ctx)
+    for transform in transforms:
+        ctx = transform(ctx)

    _params: Dict[str, nn.Parameter] = {}
    for name, param in _target.named_parameters():
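
Added note on this hunk: the @torch.no_grad decorator keeps autograd from recording the tensor copies performed during conversion, and the new empty-list default with the unguarded loop behaves the same as the old None default guarded by if transforms:. A generic illustration of why the copies want no_grad, not the module's own code path:

    import torch
    from torch import nn

    target = nn.Linear(4, 4)
    new_weight = torch.randn(4, 4, dtype=torch.bfloat16)

    with torch.no_grad():
        # In-place copy into a leaf parameter; outside no_grad this raises a
        # RuntimeError because the parameter requires grad. copy_ also casts the
        # source tensor to the parameter's dtype (float32 here).
        target.weight.copy_(new_weight)

    print(target.weight.dtype)  # torch.float32

The shared empty-list default is harmless in this function because apply_transforms only iterates over transforms and never mutates it.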
@@ -144,9 +145,9 @@ def scale_weights(ctx):

        _module.register_buffer(_key, val)

-    keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")]
+    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
    if len(keys) != 0:
-        raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.")
+        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")

    # TODO: Is this correct?
    # for key in target.state_dict():
@@ -165,7 +166,7 @@ def scale_weights(ctx):


def _default_transform(inp):
-    return inp.float()
+    return inp


class StateDictTransform(Generic[F]):
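
Added note: the old default transform upcast every incoming tensor to float32; returning the tensor unchanged preserves the checkpoint dtype (for example bfloat16), which matters now that the importer loads weights in their native precision. A two-line illustration:

    import torch

    t = torch.ones(2, dtype=torch.bfloat16)
    print(t.float().dtype)  # torch.float32  (old behavior: unconditional upcast)
    print(t.dtype)          # torch.bfloat16 (new behavior: dtype preserved)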
@@ -324,7 +325,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
    regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
    wildcard_matches = [[] for _ in range(pattern.count("*"))]

-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            for i, group in enumerate(match.groups()):
@@ -342,7 +343,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
    output_array = np.empty(shape, dtype=object)

    # Populate the array with the keys, now that we have the correct shape and ordering
-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            # Convert match groups to indices based on their position in wildcard_matches