
Norms registry #1080

Merged 16 commits on Apr 5, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -306,7 +306,7 @@ dependencies = [
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
[project.entry-points."llmfoundry_loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```

20 changes: 20 additions & 0 deletions llmfoundry/layers_registry.py
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

__all__ = [
'norms',
]
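
For reviewers, a downstream package could register its own normalization layer against this registry roughly as in the sketch below; the `identity_norm` name and `IdentityNorm` class are hypothetical and not part of this PR.

```python
# Minimal sketch, assuming the `norms` registry defined above and the
# `register_class` helper added to TypedRegistry in this PR.
import torch

from llmfoundry.layers_registry import norms


@norms.register_class('identity_norm')  # hypothetical registration name
class IdentityNorm(torch.nn.Module):
    """A no-op 'norm' that only illustrates the registration pattern."""

    def __init__(self, normalized_shape, device=None):
        # build_norm passes `normalized_shape` and `device`, so registered
        # classes should accept them.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
```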
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -9,7 +9,7 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
@@ -23,7 +23,6 @@
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -15,7 +15,7 @@
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -419,12 +419,19 @@ def __init__(
self.Wqkv._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
norm_size = self.head_dim if qk_gn else d_model
self.q_ln = norm_class(norm_size, device=device)
self.q_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = norm_class(norm_size, device=device)
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -10,7 +10,7 @@

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -72,7 +72,6 @@ def __init__(
del kwargs # unused, just to capture any extra args from the config
super().__init__()

norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

@@ -88,7 +87,11 @@ def __init__(
if k not in args_to_exclude_in_attn_class
}

self.norm_1 = norm_class(d_model, device=device)
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
@@ -100,7 +103,11 @@ def __init__(
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
self.norm_2 = norm_class(d_model, device=device)
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.ffn = build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
25 changes: 25 additions & 0 deletions llmfoundry/models/layers/layer_builders.py
@@ -0,0 +1,25 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import construct_from_registry


def build_norm(
name: str,
normalized_shape: Union[int, List[int], torch.Size],
device: Optional[str] = None,
):
kwargs = {
'normalized_shape': normalized_shape,
'device': device,
}

return construct_from_registry(name=name,
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
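
A rough usage sketch for the new builder; the `'layernorm'` name comes from the registrations in this PR, while the concrete shapes and values are illustrative.

```python
# Minimal sketch, assuming the 'layernorm' registration in
# llmfoundry/models/layers/norm.py.
import torch

from llmfoundry.models.layers.layer_builders import build_norm

norm = build_norm(
    name='layernorm',
    normalized_shape=768,
    device='cpu',
)
x = torch.randn(2, 16, 768)
print(type(norm).__name__, norm(x).shape)  # LayerNorm torch.Size([2, 16, 768])
```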
19 changes: 9 additions & 10 deletions llmfoundry/models/layers/norm.py
@@ -1,10 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Type, Union
from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms

norms.register(name='layernorm', func=torch.nn.LayerNorm)


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_autocast_enabled():
@@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
return tensor


@norms.register_class('low_precision_layernorm')
class LPLayerNorm(torch.nn.LayerNorm):

def __init__(
@@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor,
return output


@norms.register_class('rmsnorm')
class RMSNorm(torch.nn.Module):

def __init__(
@@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


@norms.register_class('low_precision_rmsnorm')
class LPRMSNorm(RMSNorm):

def __init__(
@@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.eps).to(dtype=x.dtype)


@norms.register_class('triton_rmsnorm')
class TritonRMSNorm(torch.nn.Module):

def __init__(
@@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor):
prenorm=False,
residual_in_fp32=False,
)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
'low_precision_rmsnorm': LPRMSNorm,
'triton_rmsnorm': TritonRMSNorm,
}
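
For call sites migrating off the deleted dictionary, the old lookups map onto the registry API roughly as follows (a sketch, not part of the diff):

```python
# Sketch of the before/after access pattern, assuming the names registered above.
from llmfoundry.layers_registry import norms

norm_cls = norms.get('rmsnorm')      # was: NORM_CLASS_REGISTRY['rmsnorm']
available = list(norms.get_all())    # was: list(NORM_CLASS_REGISTRY.keys())
```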
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -19,6 +19,9 @@
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
14 changes: 9 additions & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -20,7 +20,6 @@
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -42,11 +41,13 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig):
else:
config.init_device = 'meta'

if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
if config.norm_type.lower() not in norms.get_all():
norm_options = ' | '.join(norms.get_all())
raise NotImplementedError(
f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
)
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]

# CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
# both report this helping with stabilizing training
@@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig):
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)
self.norm_f = build_norm(
name=config.norm_type.lower(),
normalized_shape=config.d_model,
device=config.init_device,
)

self.rope = config.attn_config['rope']
self.rope_impl = None
8 changes: 4 additions & 4 deletions llmfoundry/models/utils/act_ckpt.py
@@ -5,10 +5,10 @@

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def pass_on_block_idx(parent: torch.nn.Module):
@@ -29,12 +29,12 @@ def get_act_ckpt_module(mod_name: str) -> Any:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_type = NORM_CLASS_REGISTRY[mod_name]
elif mod_name in norms:
mod_type = norms.get(mod_name)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
list(norms.get_all()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
5 changes: 3 additions & 2 deletions llmfoundry/models/utils/param_init_fns.py
@@ -10,8 +10,8 @@
import torch
from torch import nn

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

try:
import transformer_engine.pytorch as te
@@ -129,7 +129,8 @@ def generic_param_init_fn_(

emb_init_fn_(module.weight)

elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
elif isinstance(module,
tuple(set([norms.get(name) for name in norms.get_all()]))):
# Norm
if hasattr(module, 'weight') and isinstance(module.weight,
torch.Tensor):
2 changes: 2 additions & 0 deletions llmfoundry/registry.py
@@ -12,6 +12,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
@@ -119,4 +120,5 @@
'models',
'metrics',
'dataloaders',
'norms',
]
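
With this addition the norms registry is also reachable through the top-level `llmfoundry.registry` module; a small sketch of what that enables, assuming the import paths stay as in this diff:

```python
# Sketch: both import paths should resolve to the same registry object.
from llmfoundry import registry
from llmfoundry.layers_registry import norms

assert registry.norms is norms
print(sorted(registry.norms.get_all()))  # registered norm names
```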
7 changes: 7 additions & 0 deletions llmfoundry/utils/registry_utils.py
@@ -14,6 +14,7 @@
__all__ = ['TypedRegistry', 'create_registry', 'construct_from_registry']

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=Type)


class TypedRegistry(catalogue.Registry, Generic[T]):
@@ -36,6 +37,12 @@ def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]:
def register(self, name: str, *, func: Optional[T] = None) -> T:
return super().register(name, func=func)

def register_class(self,
name: str,
*,
func: Optional[TypeBoundT] = None) -> TypeBoundT:
return super().register(name, func=func)

def get(self, name: str) -> T:
return super().get(name)

5 changes: 3 additions & 2 deletions tests/models/test_model.py
@@ -26,8 +26,9 @@
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

from llmfoundry import ComposerHFCausalLM
from llmfoundry.layers_registry import norms
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers import build_alibi_bias
from llmfoundry.models.layers.attention import (check_alibi_support,
is_flash_v2_installed)
from llmfoundry.models.layers.blocks import MPTBlock
@@ -682,7 +683,7 @@ def test_lora_id():
assert isinstance(model.model, peft.PeftModelForCausalLM)


@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
@pytest.mark.parametrize('norm_type', norms.get_all())
@pytest.mark.parametrize('no_bias', [False, True])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [