Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make fc_type a dict to pass fc kwargs through #1201

Merged
merged 22 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
attention_implementations,
)
from llmfoundry.models.layers.layer_builders import build_fc, build_norm
from llmfoundry.models.utils.config_defaults import fc_type_defaults

__all__ = [
'scaled_multihead_dot_product_attention',
Expand Down Expand Up @@ -410,7 +411,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
Expand All @@ -429,6 +430,13 @@ def __init__(

self.head_dim = d_model // n_heads

# Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
if fc_type is None:
fc_type = fc_type_defaults
fc_type['bias'] = bias
fc_type['device'] = device
fc_type_name = fc_type['name']

if self.kv_n_heads <= 0:
raise ValueError('kv_n_heads should be greater than zero.')

Expand All @@ -449,15 +457,11 @@ def __init__(
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = attn_pdrop

fc_kwargs: dict[str, Any] = {
'bias': bias,
}
fc_kwargs['device'] = device
self.Wqkv = build_fc(
name=fc_type,
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
fc_kwargs=fc_kwargs,
fc_kwargs=fc_type,
)
# for param init fn; enables shape based init of fused layers
fuse_splits = [
Expand All @@ -484,10 +488,10 @@ def __init__(
self.attn_fn = attention_implementations.get(self.attn_impl)

self.out_proj = build_fc(
name=fc_type,
name=fc_type_name,
in_features=self.d_model,
out_features=self.d_model,
fc_kwargs=fc_kwargs,
fc_kwargs=fc_type,
)
self.out_proj._is_residual = True

Expand Down Expand Up @@ -696,7 +700,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
Expand Down Expand Up @@ -737,7 +741,7 @@ def __init__(
softmax_scale: Optional[float] = None,
attn_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
bias: bool = True,
sliding_window_size: int = -1,
Expand Down
56 changes: 25 additions & 31 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
build_ffn,
build_norm,
)
from llmfoundry.models.utils.config_defaults import (
attn_config_defaults,
fc_type_defaults,
)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
Expand All @@ -25,32 +29,6 @@
'FusedNormAttentionNorm',
]

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'flash',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
'rope_theta': 10000,
'rope_impl': 'dail',
'rope_dail_config': {
'type': 'original',
'pos_idx_in_fp32': True,
'xpos_scale_base': 512,
},
'rope_hf_config': {
'type': 'no_scaling',
'factor': 1.0,
},
}


class MPTBlock(nn.Module):

Expand All @@ -63,7 +41,7 @@ def __init__(
ffn_config: Optional[Dict] = None,
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
device: Optional[str] = None,
no_bias: bool = False,
use_pad_tok_in_ffn: bool = True,
Expand All @@ -73,15 +51,25 @@ def __init__(
attn_config = attn_config_defaults

if ffn_config is None:
ffn_config = {
self.ffn_config: dict[str, Any] = {
'ffn_type': 'mptmlp',
}
else:
self.ffn_config = ffn_config

if fc_type is None:
fc_type = fc_type_defaults
fc_type['bias'] = not no_bias
fc_type['device'] = device

self.ffn_config['fc_type'] = fc_type

self.fuse_norm_attn_norm = kwargs.get('fuse_norm_attn_norm', False)

del kwargs # unused, just to capture any extra args from the config
super().__init__()

ffn_type = ffn_config['ffn_type']
ffn_type = self.ffn_config['ffn_type']
ffn_has_norm = ffn_type in ffns_with_norm

if self.fuse_norm_attn_norm:
Expand Down Expand Up @@ -137,7 +125,7 @@ def __init__(
expansion_ratio=expansion_ratio,
device=device,
bias=not no_bias,
ffn_kwargs=ffn_config,
ffn_kwargs=self.ffn_config,
)

self.resid_attn_dropout = nn.Dropout(resid_pdrop)
Expand Down Expand Up @@ -240,7 +228,7 @@ def __init__(
args_to_exclude_in_attn_class: Set[str],
attn_config: Optional[Dict] = None,
ffn_has_norm: bool = False,
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
resid_pdrop: float = 0.0,
norm_type: str = 'low_precision_layernorm',
device: Optional[str] = None,
Expand All @@ -251,6 +239,12 @@ def __init__(
assert attn_config is not None
assert isinstance(attn_config['attn_type'], str)

# Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
if fc_type is None:
fc_type = fc_type_defaults
fc_type['bias'] = not no_bias
fc_type['device'] = device

# Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
attn_config_subset_for_attn_class = {
k: v
Expand Down
37 changes: 21 additions & 16 deletions llmfoundry/models/layers/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from llmfoundry.models.layers.dmoe import dMoE
from llmfoundry.models.layers.layer_builders import build_fc
from llmfoundry.models.utils.config_defaults import fc_type_defaults

try:
import transformer_engine.pytorch as te
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
self,
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
ffn_hidden_size: Optional[int] = None,
act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
device: Optional[str] = None,
Expand All @@ -139,24 +140,27 @@ def __init__(
expansion_ratio,
ffn_hidden_size,
)
self.fc_kwargs: dict[str, Any] = {
'bias': bias,
}

self.fc_kwargs['device'] = device
# Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
if fc_type is None:
fc_type = fc_type_defaults
snarayan21 marked this conversation as resolved.
Show resolved Hide resolved
fc_type['bias'] = bias
fc_type['device'] = device
self.fc_type = fc_type
self.fc_type_name = self.fc_type['name']
dakinggg marked this conversation as resolved.
Show resolved Hide resolved

self.up_proj = build_fc(
name=fc_type,
name=self.fc_type_name,
in_features=d_model,
out_features=ffn_hidden_size,
fc_kwargs=self.fc_kwargs,
fc_kwargs=self.fc_type,
)
self.act = act_fn
self.down_proj = build_fc(
name=fc_type,
name=self.fc_type_name,
in_features=ffn_hidden_size,
out_features=d_model,
fc_kwargs=self.fc_kwargs,
fc_kwargs=self.fc_type,
)
self.down_proj._is_residual = True

Expand All @@ -170,7 +174,7 @@ def __init__(
self,
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
ffn_hidden_size: Optional[int] = None,
act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
device: Optional[str] = None,
Expand All @@ -185,11 +189,12 @@ def __init__(
device=device,
bias=bias,
)

self.gate_proj = build_fc(
name=fc_type,
name=self.fc_type_name,
in_features=d_model,
out_features=self.up_proj.out_features,
fc_kwargs=self.fc_kwargs,
fc_kwargs=self.fc_type,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -199,7 +204,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def build_mptglu(
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
ffn_hidden_size: Optional[int] = None,
ffn_act_fn: Optional[dict] = None,
device: Optional[str] = None,
Expand All @@ -219,7 +224,7 @@ def build_mptglu(
def build_mptmlp(
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
ffn_hidden_size: Optional[int] = None,
ffn_act_fn: Optional[dict] = None,
device: Optional[str] = None,
Expand All @@ -239,7 +244,7 @@ def build_mptmlp(
def build_te_ln_mlp(
d_model: int,
expansion_ratio: Union[int, float],
fc_type: str = 'torch',
fc_type: Optional[dict[str, Any]] = None,
ffn_hidden_size: Optional[int] = None,
ffn_act_fn: Optional[dict] = None,
device: Optional[str] = None,
Expand Down Expand Up @@ -280,7 +285,7 @@ def build_torch_dmoe(
moe_normalize_expert_weights = kwargs.pop('moe_normalize_expert_weights')
uniform_expert_assignment = kwargs.pop('uniform_expert_assignment')

fc_type = kwargs.pop('fc_type', 'torch')
fc_type = kwargs.pop('fc_type', None)
del fc_type # Unused

if len(kwargs) > 0:
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/layers/layer_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def build_fc(
kwargs = {
'in_features': in_features,
'out_features': out_features,
**fc_kwargs,
**{k: v for k, v in fc_kwargs.items() if k != 'name'},
}

return construct_from_registry(
Expand Down
Loading
Loading