Norms registry (#1080)
dakinggg authored Apr 5, 2024
1 parent 60a1ab4 commit b81897a
Showing 16 changed files with 121 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -306,7 +306,7 @@ dependencies = [
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
[project.entry-points."llmfoundry_loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```

20 changes: 20 additions & 0 deletions llmfoundry/layers_registry.py
@@ -0,0 +1,20 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Type

import torch

from llmfoundry.utils.registry_utils import create_registry

# Layers
_norm_description = """The norms registry is used to register classes that implement normalization layers."""
norms = create_registry('llmfoundry',
'norms',
generic_type=Type[torch.nn.Module],
entry_points=True,
description=_norm_description)

__all__ = [
'norms',
]
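
For orientation (not part of this commit), a minimal sketch of querying the new registry; it assumes `llmfoundry.models.layers.norm` has been imported so the built-in registrations shown later in this diff have run:

```python
# Sketch only: look up norm classes registered in the new norms registry.
import llmfoundry.models.layers.norm  # noqa: F401  # triggers built-in registrations
from llmfoundry.layers_registry import norms

print(sorted(norms.get_all()))      # registered names, e.g. 'layernorm', 'rmsnorm', ...
norm_cls = norms.get('layernorm')   # resolves to torch.nn.LayerNorm
layer = norm_cls(normalized_shape=512)
```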
3 changes: 1 addition & 2 deletions llmfoundry/models/layers/__init__.py
@@ -9,7 +9,7 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm
from llmfoundry.models.layers.norm import LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
@@ -23,7 +23,6 @@
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -15,7 +15,7 @@
from torch import nn

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm


def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -419,12 +419,19 @@ def __init__(
self.Wqkv._fused = (0, fuse_splits)

if self.qk_ln or self.qk_gn:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
norm_size = self.head_dim if qk_gn else d_model
self.q_ln = norm_class(norm_size, device=device)
self.q_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)
if qk_ln:
norm_size = self.head_dim * kv_n_heads
self.k_ln = norm_class(norm_size, device=device)
self.k_ln = build_norm(
name=norm_type.lower(),
normalized_shape=norm_size,
device=device,
)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
15 changes: 11 additions & 4 deletions llmfoundry/models/layers/blocks.py
@@ -10,7 +10,7 @@

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.layers.layer_builders import build_norm

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
@@ -72,7 +72,6 @@ def __init__(
del kwargs # unused, just to capture any extra args from the config
super().__init__()

norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
assert isinstance(attn_config['attn_type'], str)
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

@@ -88,7 +87,11 @@ def __init__(
if k not in args_to_exclude_in_attn_class
}

self.norm_1 = norm_class(d_model, device=device)
self.norm_1 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
@@ -100,7 +103,11 @@ def __init__(
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
self.norm_2 = norm_class(d_model, device=device)
self.norm_2 = build_norm(
name=norm_type.lower(),
normalized_shape=d_model,
device=device,
)
self.ffn = build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
25 changes: 25 additions & 0 deletions llmfoundry/models/layers/layer_builders.py
@@ -0,0 +1,25 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import construct_from_registry


def build_norm(
name: str,
normalized_shape: Union[int, List[int], torch.Size],
device: Optional[str] = None,
):
kwargs = {
'normalized_shape': normalized_shape,
'device': device,
}

return construct_from_registry(name=name,
registry=norms,
pre_validation_function=torch.nn.Module,
kwargs=kwargs)
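
A hedged usage sketch of the new `build_norm` helper (the shape, device, and norm name below are illustrative, and importing `llmfoundry.models.layers.norm` is assumed to be what registers the built-in norms):

```python
import torch

import llmfoundry.models.layers.norm  # noqa: F401  # registers built-in norms
from llmfoundry.models.layers.layer_builders import build_norm

norm = build_norm(
    name='low_precision_layernorm',  # any name present in the norms registry
    normalized_shape=768,
    device='cpu',
)
x = torch.randn(2, 16, 768)
y = norm(x)  # the returned module behaves like the underlying norm class
```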
19 changes: 9 additions & 10 deletions llmfoundry/models/layers/norm.py
@@ -1,10 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, List, Optional, Type, Union
from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms

norms.register(name='layernorm', func=torch.nn.LayerNorm)


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
if torch.is_autocast_enabled():
@@ -18,6 +22,7 @@ def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
return tensor


@norms.register_class('low_precision_layernorm')
class LPLayerNorm(torch.nn.LayerNorm):

def __init__(
@@ -62,6 +67,7 @@ def rms_norm(x: torch.Tensor,
return output


@norms.register_class('rmsnorm')
class RMSNorm(torch.nn.Module):

def __init__(
@@ -84,6 +90,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


@norms.register_class('low_precision_rmsnorm')
class LPRMSNorm(RMSNorm):

def __init__(
@@ -111,6 +118,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.eps).to(dtype=x.dtype)


@norms.register_class('triton_rmsnorm')
class TritonRMSNorm(torch.nn.Module):

def __init__(
@@ -150,12 +158,3 @@ def forward(self, x: torch.Tensor):
prenorm=False,
residual_in_fp32=False,
)


NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
'low_precision_rmsnorm': LPRMSNorm,
'triton_rmsnorm': TritonRMSNorm,
}
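
With the registry in place, a downstream project could add its own norm using the same decorator pattern shown above; the class and the `'scale_only_norm'` name below are hypothetical and not part of this commit:

```python
from typing import List, Optional, Union

import torch

from llmfoundry.layers_registry import norms


@norms.register_class('scale_only_norm')  # hypothetical name
class ScaleOnlyNorm(torch.nn.Module):
    """Rescales features with a learned weight; no centering or variance term."""

    def __init__(self,
                 normalized_shape: Union[int, List[int], torch.Size],
                 device: Optional[str] = None):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.ones(normalized_shape, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight
```

Once registered, it should be buildable like any other entry, e.g. `build_norm(name='scale_only_norm', normalized_shape=d_model)`.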
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -19,6 +19,9 @@
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)

ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
14 changes: 9 additions & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -20,7 +20,6 @@
from composer.utils import dist

from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

if is_flash_v2_installed():
try: # This try...except is needed because transformers requires it despite the 'if' statement above
@@ -42,11 +41,13 @@
from transformers.models.llama.modeling_llama import \
LlamaRotaryEmbedding as HFRotaryEmbedding

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import (attn_bias_shape,
build_attn_bias, gen_slopes)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.ffn import build_ffn as build_ffn
from llmfoundry.models.layers.layer_builders import build_norm
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# NOTE: All utils are imported directly even if unused so that
@@ -297,12 +298,11 @@ def __init__(self, config: MPTConfig):
else:
config.init_device = 'meta'

if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
if config.norm_type.lower() not in norms.get_all():
norm_options = ' | '.join(norms.get_all())
raise NotImplementedError(
f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
)
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]

# CogView (https://arxiv.org/abs/2105.13290) and GLM-130B (https://arxiv.org/abs/2210.02414)
# both report this helping with stabilizing training
@@ -329,7 +329,11 @@ def __init__(self, config: MPTConfig):
block.max_block_idx = config.n_layers - 1
pass_on_block_idx(block)

self.norm_f = norm_class(config.d_model, device=config.init_device)
self.norm_f = build_norm(
name=config.norm_type.lower(),
normalized_shape=config.d_model,
device=config.init_device,
)

self.rope = config.attn_config['rope']
self.rope_impl = None
8 changes: 4 additions & 4 deletions llmfoundry/models/utils/act_ckpt.py
@@ -5,10 +5,10 @@

import torch

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def pass_on_block_idx(parent: torch.nn.Module):
@@ -29,12 +29,12 @@ def get_act_ckpt_module(mod_name: str) -> Any:
mod_type = ATTN_CLASS_REGISTRY[mod_name]
elif mod_name in FFN_CLASS_REGISTRY:
mod_type = FFN_CLASS_REGISTRY[mod_name]
elif mod_name in NORM_CLASS_REGISTRY:
mod_type = NORM_CLASS_REGISTRY[mod_name]
elif mod_name in norms:
mod_type = norms.get(mod_name)
else:
msg = ', '.join(
list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock'])
list(norms.get_all()) + ['MPTBlock'])
raise ValueError(
f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
)
5 changes: 3 additions & 2 deletions llmfoundry/models/utils/param_init_fns.py
@@ -10,8 +10,8 @@
import torch
from torch import nn

from llmfoundry.layers_registry import norms
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

try:
import transformer_engine.pytorch as te
@@ -129,7 +129,8 @@ def generic_param_init_fn_(

emb_init_fn_(module.weight)

elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
elif isinstance(module,
tuple(set([norms.get(name) for name in norms.get_all()]))):
# Norm
if hasattr(module, 'weight') and isinstance(module.weight,
torch.Tensor):
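The `isinstance` check above collects every registered norm class into a tuple; the same pattern in isolation (assuming the built-in norms have been registered by importing `llmfoundry.models.layers.norm`):

```python
import torch

import llmfoundry.models.layers.norm  # noqa: F401  # registers built-in norms
from llmfoundry.layers_registry import norms

# Build a tuple of all registered norm classes for isinstance checks.
norm_classes = tuple(set(norms.get(name) for name in norms.get_all()))
print(isinstance(torch.nn.LayerNorm(8), norm_classes))  # True: registered as 'layernorm'
```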
2 changes: 2 additions & 0 deletions llmfoundry/registry.py
@@ -12,6 +12,7 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.layers_registry import norms
from llmfoundry.utils.registry_utils import create_registry

_loggers_description = (
@@ -119,4 +120,5 @@
'models',
'metrics',
'dataloaders',
'norms',
]
7 changes: 7 additions & 0 deletions llmfoundry/utils/registry_utils.py
@@ -14,6 +14,7 @@
__all__ = ['TypedRegistry', 'create_registry', 'construct_from_registry']

T = TypeVar('T')
TypeBoundT = TypeVar('TypeBoundT', bound=Type)


class TypedRegistry(catalogue.Registry, Generic[T]):
@@ -36,6 +37,12 @@ def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]:
def register(self, name: str, *, func: Optional[T] = None) -> T:
return super().register(name, func=func)

def register_class(self,
name: str,
*,
func: Optional[TypeBoundT] = None) -> TypeBoundT:
return super().register(name, func=func)

def get(self, name: str) -> T:
return super().get(name)

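For contrast with the decorator-style `register_class` used in `norm.py`, an existing class can also be registered directly; the `'identity_norm'` name below is hypothetical and only illustrates the call:

```python
import torch

from llmfoundry.layers_registry import norms

# Direct registration of an existing nn.Module subclass under a new name.
# Note: build_norm forwards normalized_shape/device kwargs, so a class
# registered for real use should accept them; torch.nn.Identity here only
# demonstrates the registration API itself.
norms.register(name='identity_norm', func=torch.nn.Identity)

assert 'identity_norm' in norms.get_all()
assert norms.get('identity_norm') is torch.nn.Identity
```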
5 changes: 3 additions & 2 deletions tests/models/test_model.py
@@ -26,8 +26,9 @@
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

from llmfoundry import ComposerHFCausalLM
from llmfoundry.layers_registry import norms
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers import build_alibi_bias
from llmfoundry.models.layers.attention import (check_alibi_support,
is_flash_v2_installed)
from llmfoundry.models.layers.blocks import MPTBlock
@@ -682,7 +683,7 @@ def test_lora_id():
assert isinstance(model.model, peft.PeftModelForCausalLM)


@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
@pytest.mark.parametrize('norm_type', norms.get_all())
@pytest.mark.parametrize('no_bias', [False, True])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [