
Pass FC type along for all FFN types (#1196)
dakinggg authored May 11, 2024
1 parent 994209c commit eef4872
Showing 2 changed files with 46 additions and 8 deletions.
23 changes: 15 additions & 8 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -3,6 +3,7 @@

 """A HuggingFace-style model configuration."""

+import copy
 import warnings
 from typing import Any, Dict, Optional, Union

@@ -55,15 +56,15 @@ def __init__(
         resid_pdrop: float = 0.0,
         emb_pdrop: float = 0.0,
         learned_pos_emb: bool = True,
-        attn_config: Dict = attn_config_defaults,
-        ffn_config: Dict = ffn_config_defaults,
+        attn_config: Optional[Dict] = None,
+        ffn_config: Optional[Dict] = None,
         init_device: str = 'cpu',
         logit_scale: Optional[Union[float, str]] = None,
         no_bias: bool = False,
         embedding_fraction: float = 1.0,
         norm_type: str = 'low_precision_layernorm',
         use_cache: bool = False,
-        init_config: Dict = init_config_defaults,
+        init_config: Optional[Dict] = None,
         fc_type: str = 'torch',
         tie_word_embeddings: bool = True,
         use_pad_tok_in_ffn: bool = True,
@@ -147,15 +148,21 @@ def __init__(
         self.resid_pdrop = resid_pdrop
         self.emb_pdrop = emb_pdrop
         self.learned_pos_emb = learned_pos_emb
-        self.attn_config = attn_config
-        self.ffn_config = ffn_config
+        self.attn_config = attn_config if attn_config is not None else copy.deepcopy(
+            attn_config_defaults,
+        )
+        self.ffn_config = ffn_config if ffn_config is not None else copy.deepcopy(
+            ffn_config_defaults,
+        )
         self.init_device = init_device
         self.logit_scale = logit_scale
         self.no_bias = no_bias
         self.embedding_fraction = embedding_fraction
         self.norm_type = norm_type
         self.use_cache = use_cache
-        self.init_config = init_config
+        self.init_config = init_config if init_config is not None else copy.deepcopy(
+            init_config_defaults,
+        )
         self.fc_type = fc_type
         self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

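The hunk above pairs the new `Optional[Dict] = None` signature with a deepcopy of the module-level defaults in the constructor body, so each `MPTConfig` instance gets its own dict instead of a reference to the shared default. A minimal standalone sketch of that idiom follows (illustrative only, not the llm-foundry code; the class and dict names are made up):

import copy
from typing import Optional

CONFIG_DEFAULTS = {'ffn_type': 'mptmlp'}

class SharedDefault:
    # Anti-pattern: every instance that omits `cfg` references the same dict,
    # so mutating one instance's config also mutates the module-level default.
    def __init__(self, cfg: dict = CONFIG_DEFAULTS):
        self.cfg = cfg

class CopiedDefault:
    # Idiom used in the diff above: default to None, then deepcopy the shared
    # defaults so each instance owns an independent dict.
    def __init__(self, cfg: Optional[dict] = None):
        self.cfg = cfg if cfg is not None else copy.deepcopy(CONFIG_DEFAULTS)

a, b = SharedDefault(), SharedDefault()
a.cfg['fc_type'] = 'te'
print(b.cfg)   # {'ffn_type': 'mptmlp', 'fc_type': 'te'}  <- mutation leaked

c, d = CopiedDefault(), CopiedDefault()
c.cfg['fc_type'] = 'te'
print(d.cfg)   # {'ffn_type': 'mptmlp'}  <- isolated
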
@@ -306,14 +313,14 @@ def _validate_config(self) -> None:
                     + 'pip install flash-attn==1.0.6 --no-build-isolation \n' +
                     'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156',
                 )
+
+        self.ffn_config['fc_type'] = self.fc_type
         if self.ffn_config['ffn_type'] == 'mptgeglu':
             raise ValueError(
                 'API CHANGE: `ffn_type=="mptgeglu"` changed to `ffn_type=="mptglu"`. '
                 +
                 'See [#829](https://github.com/mosaicml/llm-foundry/pull/829) for details.',
             )
-        elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
-            self.ffn_config['fc_type'] = self.fc_type
         elif self.ffn_config['ffn_type'] in ffns_with_megablocks:
             self.ffn_config['return_bias'] = False
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
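
The commit's headline change is the hunk above: `self.ffn_config['fc_type'] = self.fc_type` now runs before the `ffn_type` branches, so megablocks and TransformerEngine FFNs receive `fc_type` too, not only `mptmlp`/`mptglu`. A hedged sketch of how that could be checked (it assumes `MPTConfig.__init__` runs this validation and that the megablocks extra is installed; the sizes mirror the new test below and are otherwise illustrative):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Construct a config with a megablocks FFN type.
cfg = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    fc_type='torch',
    attn_config={'attn_impl': 'torch'},
    ffn_config={
        'ffn_type': 'mb_moe',
        'ffn_hidden_size': 1024,
        'ffn_act_fn': {'name': 'gelu'},
        'moe_world_size': 1,
    },
)

# After this commit, fc_type is propagated into ffn_config for every FFN type.
assert cfg.ffn_config['fc_type'] == 'torch'
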
31 changes: 31 additions & 0 deletions tests/models/test_model.py
@@ -979,6 +979,37 @@ def test_mpt_creation(
         assert block.resid_ffn_dropout.p == 0.2


+@pytest.mark.gpu
+def test_mb_mpt_creation():
+    # Test that the config constructs the model as expected.
+    hf_config = MPTConfig(
+        init_device='cpu',
+        d_model=128,
+        n_heads=4,
+        n_layers=2,
+        expansion_ratio=2,
+        max_seq_len=2048,
+        emb_pdrop=0.1,
+        resid_pdrop=0.2,
+        attn_config={
+            'attn_impl': 'torch',
+        },
+        norm_type='low_precision_layernorm',
+        no_bias=True,
+        tie_word_embeddings=False,
+        ffn_config={
+            'ffn_type': 'mb_moe',
+            'ffn_hidden_size': 1024,
+            'ffn_act_fn': {
+                'name': 'gelu',
+            },
+            'moe_world_size': 1,
+        },
+    )
+
+    _ = MPTForCausalLM(hf_config)
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize('attention_impl', ['flash', 'torch'])
 @pytest.mark.parametrize(

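The new `test_mb_mpt_creation` test only asserts that an `mb_moe` model can be constructed; it carries the `gpu` marker, so a CUDA device and the megablocks dependency are assumed. One way to run just this test, sketched as a Python driver rather than a shell command:

import pytest

# Select only the test added in this commit; '-m gpu' honors @pytest.mark.gpu.
pytest.main([
    'tests/models/test_model.py',
    '-m', 'gpu',
    '-k', 'test_mb_mpt_creation',
])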