Add linear layer and ffn config to enable TransformerEngine layers (with FP8) (#432)

* adding te Linear for fp8 support

* integ warning / config auto updt

* req updts

* updt FC_CLASS_REGISTRY to enable init

* updt fc.py

* skip running install in readme

* Update attention.py

* add te layers and linear / ffn config

* cleanup

* add te linear and fp8 instructions

* install help

* updt imports

* updt yamls

* add ffn file

* leave ex_ratio alone

* updt warning

* lint

* add default PE

* updt tutorial

* add tutorial warning

* cleanup

* updt yaml checks

* undo line erase

* disable torch._dynamo for te.LayerNormMLP

* updt te version

* fix conditional
vchiley authored Jul 17, 2023
1 parent 275fa56 commit 340a566
Showing 13 changed files with 328 additions and 82 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -151,6 +151,15 @@ pip install cmake packaging torch # setup.py requires these be installed
pip install -e ".[gpu]" # or pip install -e . if no NVIDIA GPU
```

### TransformerEngine and amp_fp8 support
NVIDIA H100 GPUs have FP8 support; using it additionally requires the following installations:
<!--pytest.mark.skip-->
```bash
pip install flash-attn==1.0.7 --no-build-isolation
pip install git+https://github.com/NVIDIA/[email protected]
```
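
As an optional sanity check (a suggested verification, assuming the installs above succeeded), you can confirm that TransformerEngine's PyTorch layers import:
<!--pytest.mark.skip-->
```bash
python -c "import transformer_engine.pytorch as te; print(te.Linear)"
```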

See [here](https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#TransformerEngine-and-amp_fp8-support) for more details on enabling TransformerEngine layers and amp_fp8.

### AMD (BETA support)

22 changes: 22 additions & 0 deletions TUTORIAL.md
@@ -350,6 +350,28 @@ lora:
### How do I deploy with ONNX/FasterTransformer?
- Check out the `scripts/inference` directory for instructions and scripts.

### TransformerEngine and amp_fp8 support
Once [installed](https://github.com/mosaicml/llm-foundry/tree/main#TransformerEngine-and-amp_fp8-support), if you are running on an H100, you can enable FP8 with TransformerEngine (`te`) layers by setting, for example:
<!--pytest.mark.skip-->
```yaml
precision: amp_fp8
model:
fc_type: te
```
in the training yaml.

Setting
<!--pytest.mark.skip-->
```yaml
model:
ffn_config_defaults:
ffn_type: te_ln_mlp
```
enables [TransformerEngine's LayerNormMLP](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.LayerNormMLP) layer, which supports sequence parallelism when configured correctly.
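
For example, a minimal combined sketch (assuming an H100 and the installation steps above) that enables `amp_fp8`, `te` linear layers, and `te_ln_mlp` together might look like:
<!--pytest.mark.skip-->
```yaml
precision: amp_fp8
model:
  fc_type: te
  ffn_config_defaults:
    ffn_type: te_ln_mlp
```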

WARNING: `state_dicts` generated with `ffn_type: te_ln_mlp` will NOT directly map to `state_dicts` generated using the default network configurations. We do not have control over how `te.LayerNormMLP` is implemented and therefore cannot easily reconcile it with the default implementation (or any other implementation).

### How expensive is it to build LLMs?
- Check out our blog post [GPT3-Quality for <$500k](https://www.mosaicml.com/blog/gpt-3-quality-for-500k) for guidance on LLM training times and costs.

8 changes: 6 additions & 2 deletions llmfoundry/__init__.py
@@ -16,7 +16,9 @@
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention,
triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTMLP, MPTBlock
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel,
@@ -38,8 +40,10 @@
'build_finetuning_dataloader',
'MixtureOfDenoisersCollator',
'Seq2SeqFinetuningCollator',
'MPTMLP',
'MPTBlock',
'FFN_CLASS_REGISTRY',
'MPTMLP',
'build_ffn',
'MPTConfig',
'MPTPreTrainedModel',
'MPTModel',
7 changes: 6 additions & 1 deletion llmfoundry/models/layers/__init__.py
@@ -5,8 +5,10 @@
ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTMLP, MPTBlock
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

__all__ = [
@@ -23,5 +25,8 @@
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
]
53 changes: 37 additions & 16 deletions llmfoundry/models/layers/attention.py
@@ -13,7 +13,8 @@
from packaging import version
from torch import nn

from llmfoundry.models.layers.norm import LPLayerNorm
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
@@ -342,7 +343,8 @@ def __init__(
qk_ln: bool = False,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
low_precision_layernorm: bool = False,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
@@ -359,15 +361,22 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
fc_kwargs = {}
if fc_type != 'te':
fc_kwargs['device'] = device
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
self.d_model,
3 * self.d_model,
**fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = (d_model, 2 * d_model)
self.Wqkv._fused = (0, fuse_splits) # type: ignore

if self.qk_ln:
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
self.q_ln = layernorm_class(self.d_model, device=device)
self.k_ln = layernorm_class(self.d_model, device=device)
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
self.q_ln = norm_class(self.d_model, device=device)
self.k_ln = norm_class(self.d_model, device=device)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
@@ -391,7 +400,11 @@ def __init__(
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
)
self.out_proj._is_residual = True # type: ignore

def forward(
@@ -406,7 +419,7 @@ def forward(
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.chunk(3, dim=2)

@@ -452,7 +465,8 @@ def __init__(
qk_ln: bool = False,
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
low_precision_layernorm: bool = False,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
verbose: int = 0,
device: Optional[str] = None,
):
@@ -470,23 +484,26 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.head_dim)
self.attn_dropout_p = attn_pdrop

fc_kwargs = {}
if fc_type != 'te':
fc_kwargs['device'] = device
# NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll
# want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but
# Wkv shouldn't be TensorParallel
# - vchiley
self.Wqkv = nn.Linear(
self.Wqkv = FC_CLASS_REGISTRY[fc_type](
d_model,
d_model + 2 * self.head_dim,
device=device,
**fc_kwargs,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = (d_model, d_model + self.head_dim)
self.Wqkv._fused = (0, fuse_splits) # type: ignore

if self.qk_ln:
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
self.q_ln = layernorm_class(d_model, device=device)
self.k_ln = layernorm_class(self.head_dim, device=device)
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
self.q_ln = norm_class(d_model, device=device)
self.k_ln = norm_class(self.head_dim, device=device)

if self.attn_impl == 'flash':
self.attn_fn = flash_attn_fn
@@ -510,7 +527,11 @@ def __init__(
else:
raise ValueError(f'{attn_impl=} is an invalid setting.')

self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj = FC_CLASS_REGISTRY[fc_type](
self.d_model,
self.d_model,
**fc_kwargs,
)
self.out_proj._is_residual = True # type: ignore

def forward(
@@ -525,7 +546,7 @@ def forward(
qkv = self.Wqkv(x)

if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

query, key, value = qkv.split(
[self.d_model, self.head_dim, self.head_dim], dim=2)
87 changes: 41 additions & 46 deletions llmfoundry/models/layers/blocks.py
@@ -9,53 +9,40 @@
import torch.nn as nn

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


class MPTMLP(nn.Module):

def __init__(self,
d_model: int,
expansion_ratio: int,
device: Optional[str] = None):
super().__init__()
self.up_proj = nn.Linear(d_model,
expansion_ratio * d_model,
device=device)
self.act = nn.GELU(approximate='none')
self.down_proj = nn.Linear(expansion_ratio * d_model,
d_model,
device=device)
self.down_proj._is_residual = True # type: ignore

def forward(self, x):
return self.down_proj(self.act(self.up_proj(x)))


class MPTBlock(nn.Module):

def __init__(
self,
d_model: int,
n_heads: int,
expansion_ratio: int,
attn_config: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
},
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
verbose: int = 0,
device: Optional[str] = None,
**kwargs):
self,
d_model: int,
n_heads: int,
expansion_ratio: int,
attn_config: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
},
ffn_config: Dict = {
'ffn_type': 'mptmlp',
},
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
verbose: int = 0,
fc_type: str = 'torch',
device: Optional[str] = None,
**kwargs,
):
del kwargs # unused, just to capture any extra args from the config
super().__init__()

@@ -64,21 +51,27 @@ def __init__(

self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(
d_model=d_model,
n_heads=n_heads,
attn_impl=attn_config['attn_impl'],
clip_qkv=attn_config['clip_qkv'],
qk_ln=attn_config['qk_ln'],
softmax_scale=attn_config['softmax_scale'],
attn_pdrop=attn_config['attn_pdrop'],
d_model=d_model,
n_heads=n_heads,
norm_type=norm_type,
fc_type=fc_type,
verbose=verbose,
device=device,
)
self.norm_2 = norm_class(d_model, device=device)
self.ffn = MPTMLP(
self.norm_2 = None
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
False):
self.norm_2 = norm_class(d_model, device=device)
self.ffn = build_ffn(
d_model=d_model,
expansion_ratio=expansion_ratio,
device=device,
**ffn_config,
)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
@@ -100,7 +93,9 @@ def forward(
is_causal=is_causal,
)
x = x + self.resid_attn_dropout(b)
m = self.norm_2(x)
m = x
if self.norm_2 is not None:
m = self.norm_2(x)
n = self.ffn(m)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
14 changes: 14 additions & 0 deletions llmfoundry/models/layers/fc.py
@@ -0,0 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from torch import nn

FC_CLASS_REGISTRY = {
'torch': nn.Linear,
}

try:
import transformer_engine.pytorch as te
FC_CLASS_REGISTRY['te'] = te.Linear
except:
pass
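
A minimal usage sketch (hypothetical example, mirroring how `attention.py` selects its layers via `FC_CLASS_REGISTRY`); `fc_type='te'` additionally requires TransformerEngine to be installed:
```python
# Hypothetical sketch: building a projection layer through FC_CLASS_REGISTRY,
# following the same pattern used in attention.py above.
import torch

from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY

fc_type = 'torch'  # or 'te' when transformer_engine is installed
fc_kwargs = {}
if fc_type != 'te':
    # matching attention.py, `device` is only passed to the torch-backed linear
    fc_kwargs['device'] = 'cpu'

proj = FC_CLASS_REGISTRY[fc_type](512, 512, **fc_kwargs)
out = proj(torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 16, 512])
```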