diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 8f37b39306..e18e611ca6 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -19,11 +19,21 @@
 log = logging.getLogger(__name__)
 
 
-def _resolve_ffn_hidden_and_exp_ratio(
+def resolve_ffn_hidden_size(
     d_model: int,
     expansion_ratio: Union[int, float],
     ffn_hidden_size: Optional[int] = None,
-) -> tuple[Union[int, float], int]:
+) -> int:
+    """Resolve the hidden size of the feed-forward network.
+
+    Args:
+        d_model (int): The dimension of the input and output of the feed-forward network.
+        expansion_ratio (Union[int, float]): The expansion ratio of the feed-forward network.
+        ffn_hidden_size (Optional[int]): The hidden size of the feed-forward network.
+
+    Returns:
+        int: The hidden size of the feed-forward network.
+    """
     if ffn_hidden_size is not None:
         log.info(
             f'`expansion_ratio` (={expansion_ratio}) ignored when `ffn_hidden_size` (={ffn_hidden_size}) is specified.'
@@ -32,9 +42,9 @@ def _resolve_ffn_hidden_and_exp_ratio(
         ffn_hidden_size = int(d_model * expansion_ratio)
         if ffn_hidden_size != d_model * expansion_ratio:
             raise ValueError(
-                f'`d_model * expansion_ratio` ({ffn_hidden_size}) must be an integer.'
+                f'`d_model * expansion_ratio` must be an integer ({d_model=}; {expansion_ratio=}; {d_model * expansion_ratio=}).'
             )
-    return expansion_ratio, ffn_hidden_size
+    return ffn_hidden_size
 
 
 class MPTMLP(nn.Module):
@@ -49,8 +59,8 @@ def __init__(
         bias: bool = True,
     ):
         super().__init__()
-        expansion_ratio, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
-            d_model, expansion_ratio, ffn_hidden_size)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                  ffn_hidden_size)
         self.fc_kwargs: dict[str, Any] = {
             'bias': bias,
         }
@@ -138,8 +148,8 @@
         )
     elif ffn_type == 'te_ln_mlp':
         assert te is not None
-        _, ffn_hidden_size = _resolve_ffn_hidden_and_exp_ratio(
-            d_model, expansion_ratio, ffn_hidden_size)
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                  ffn_hidden_size)
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=ffn_hidden_size,
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index b9b4929ad0..2ecc726aa3 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -70,7 +70,7 @@ def __init__(
             d_model (int): The size of the embedding dimension of the model.
             n_heads (int): The number of attention heads.
             n_layers (int): The number of layers in the model.
-            expansion_ratio (int, float): The ratio of the up/down scale in the ffn.
+            expansion_ratio (Union[int, float]): The ratio of the up/down scale in the ffn.
             max_seq_len (int): The maximum sequence length of the model.
             vocab_size (int): The size of the vocabulary.
             resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
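A minimal usage sketch of the renamed helper (illustrative only, not part of the diff; it assumes resolve_ffn_hidden_size is importable from llmfoundry.models.layers.ffn as shown above):

    # Illustrative example of the resolution logic introduced above.
    from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size

    # An explicit ffn_hidden_size wins; expansion_ratio is logged and ignored.
    assert resolve_ffn_hidden_size(d_model=128, expansion_ratio=2, ffn_hidden_size=256) == 256

    # Otherwise the hidden size is d_model * expansion_ratio, which may use a fractional ratio ...
    assert resolve_ffn_hidden_size(d_model=128, expansion_ratio=2) == 256
    assert resolve_ffn_hidden_size(d_model=128, expansion_ratio=1.5) == 192

    # ... but the product must be an integer, otherwise a ValueError is raised.
    try:
        resolve_ffn_hidden_size(d_model=128, expansion_ratio=1.231)
    except ValueError:
        pass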
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 13fe50d5cb..6d48d115fd 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -514,14 +514,21 @@ def test_opt_wrapping():
 @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
 @pytest.mark.parametrize('no_bias', [False, True])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
+@pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
+    (2, None),
+    (1.231, None),
+    (2, 128),
+    (2, 256),
+])
+def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
+                      expansion_ratio: Union[int, float], ffn_hidden_size: int):
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
         d_model=128,
         n_heads=4,
         n_layers=2,
-        expansion_ratio=2,
+        expansion_ratio=expansion_ratio,
         max_seq_len=2048,
         emb_pdrop=0.1,
         resid_pdrop=0.2,
@@ -531,13 +538,24 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
         norm_type=norm_type,
         no_bias=no_bias,
         tie_word_embeddings=tie_word_embeddings,
+        ffn_config={
+            'ffn_type': 'mptmlp',
+            'ffn_hidden_size': ffn_hidden_size,
+        },
     )
+    if hf_config.d_model * hf_config.expansion_ratio != int(
+            hf_config.d_model * hf_config.expansion_ratio):
+        pytest.xfail('d_model * expansion_ratio must be an integer.')
+
     mpt = MPTForCausalLM(hf_config)
 
     assert mpt.config.d_model == 128
     assert mpt.config.n_heads == 4
     assert mpt.config.n_layers == 2
-    assert mpt.config.expansion_ratio == 2
+    if ffn_hidden_size is None:
+        assert mpt.config.expansion_ratio == expansion_ratio
+    else:
+        assert mpt.config.ffn_config['ffn_hidden_size'] == ffn_hidden_size
     assert mpt.config.max_seq_len == 2048
 
     assert mpt.transformer.wte.weight.shape == torch.Size(
@@ -551,21 +569,19 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert len(mpt.transformer.blocks) == 2
 
     d_model = hf_config.d_model
+    if ffn_hidden_size is None:
+        ffn_hidden_size = int(hf_config.d_model * hf_config.expansion_ratio)
     for block in mpt.transformer.blocks:
         assert isinstance(block, MPTBlock)
         assert block.norm_1.weight.shape == torch.Size([d_model])
         assert block.norm_2 is not None
         assert block.norm_2.weight.shape == torch.Size([d_model])
         assert isinstance(block.ffn.up_proj, nn.Linear)
-        assert block.ffn.up_proj.weight.shape == torch.Size([
-            int(hf_config.d_model * hf_config.expansion_ratio),
-            hf_config.d_model
-        ])
+        assert block.ffn.up_proj.weight.shape == torch.Size(
+            [ffn_hidden_size, hf_config.d_model])
         assert isinstance(block.ffn.down_proj, nn.Linear)
-        assert block.ffn.down_proj.weight.shape == torch.Size([
-            hf_config.d_model,
-            int(hf_config.d_model * hf_config.expansion_ratio)
-        ])
+        assert block.ffn.down_proj.weight.shape == torch.Size(
+            [hf_config.d_model, ffn_hidden_size])
         assert block.resid_attn_dropout.p == 0.2
         assert block.resid_ffn_dropout.p == 0.2
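For context, a short end-to-end sketch of the override path the new test parametrization covers (illustrative only, not part of the diff; the import path is an assumption based on the modules touched here):

    # Illustrative only: mirrors the (expansion_ratio=2, ffn_hidden_size=128) test case.
    from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

    # Passing ffn_hidden_size via ffn_config overrides d_model * expansion_ratio.
    config = MPTConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        expansion_ratio=2,
        max_seq_len=2048,
        ffn_config={'ffn_type': 'mptmlp', 'ffn_hidden_size': 128},
    )
    model = MPTForCausalLM(config)
    # With the override, the up projection is (ffn_hidden_size, d_model) = (128, 128)
    # rather than (d_model * expansion_ratio, d_model) = (256, 128).
    assert tuple(model.transformer.blocks[0].ffn.up_proj.weight.shape) == (128, 128)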