
Extendability refactors #1290

Merged · 19 commits · Jun 20, 2024
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -46,7 +46,7 @@

def build_finetuning_dataloader(
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
device_batch_size: Union[int, float],
dataset: Dict[str, Any],
num_workers: int,
drop_last: bool = False,
2 changes: 1 addition & 1 deletion llmfoundry/data/text_data.py
@@ -300,7 +300,7 @@ def build_streams(streams: Optional[Dict[str, Any]] = None,):

def build_text_dataloader(
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int,
device_batch_size: Union[int, float],
dataset: Dict[str, Any],
drop_last: bool,
num_workers: int,
19 changes: 19 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -212,13 +212,32 @@ def apply_ffn(
indices = None
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert unpad_input is not None
attention_mask = self.slice_attention_mask(attention_mask, seq_len)
m, indices, _, _ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffn and attention_mask is not None:
assert pad_input is not None
n = pad_input(n, indices, batch_size, seq_len)
return n

def slice_attention_mask(
self,
attention_mask: torch.ByteTensor,
seq_len: int,
) -> torch.ByteTensor:
"""Slice attention mask to the correct size.

Can be overridden by subclasses to apply different slicing logic.

Args:
attention_mask (torch.ByteTensor): The attention mask.
seq_len (int): The sequence length.

Returns:
torch.ByteTensor: The sliced attention mask.
"""
return attention_mask


class FusedNormAttentionNorm(nn.Module):

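The new `slice_attention_mask` hook is the extension point here: a block subclass can adjust how the attention mask is trimmed before tokens are unpadded for the FFN, without copying `apply_ffn`. A minimal sketch of a hypothetical subclass (the class name and slicing rule are illustrative, not part of this PR):

```python
import torch

from llmfoundry.models.layers.blocks import MPTBlock


class WindowedFFNBlock(MPTBlock):
    """Hypothetical block that only keeps the mask for the last `seq_len` positions."""

    def slice_attention_mask(
        self,
        attention_mask: torch.ByteTensor,
        seq_len: int,
    ) -> torch.ByteTensor:
        # Trim the mask to the current window before unpad_input sees it.
        return attention_mask[:, -seq_len:]
```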
25 changes: 11 additions & 14 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -14,22 +14,12 @@
check_alibi_support,
is_flash_v2_installed,
)

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
from llmfoundry.models.layers.layer_builders import build_norm, build_fc, build_ffn # type: ignore (see note)
from llmfoundry.models.layers.dmoe import dMoE # type: ignore (see note)
from llmfoundry.layers_registry import norms # type: ignore (see note)
from llmfoundry.utils.registry_utils import construct_from_registry # type: ignore (see note)
from llmfoundry.models.utils.config_defaults import (
attn_config_defaults,
fc_type_defaults,
ffn_config_defaults,
init_config_defaults,
fc_type_defaults,
) # type: ignore (see note)
)


class MPTConfig(PretrainedConfig):
@@ -196,6 +186,13 @@ def _set_config_defaults(
)
return config

def validate_attention_config(self) -> None:
if 'seq_parallel_world_size' in self.attn_config and self.attn_config[
'seq_parallel_world_size'] is None:
del self.attn_config['seq_parallel_world_size']
if self.attn_config.get('seq_parallel_world_size', 1) > 1:
raise NotImplementedError('Sequence Parallelism is not supported.')

def _validate_config(self) -> None:
# set config defaults
self.attn_config = self._set_config_defaults(
@@ -336,5 +333,5 @@ def _validate_config(self) -> None:
raise ImportError(
'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6',
)
if (self.attn_config.get('seq_parallel_world_size', 1) or 1) > 1:
raise NotImplementedError('Sequence Parallelism is not supported.')

self.validate_attention_config()
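Pulling the sequence-parallelism check out into `validate_attention_config` gives config subclasses a single method to extend or replace instead of overriding all of `_validate_config`. A hedged sketch of how that might look (the subclass and the extra rule are hypothetical):

```python
from llmfoundry.models.mpt.configuration_mpt import MPTConfig


class FlashOnlyMPTConfig(MPTConfig):
    """Hypothetical config variant that adds its own attention check."""

    def validate_attention_config(self) -> None:
        super().validate_attention_config()  # keep the sequence-parallelism check
        # Hypothetical extra constraint for this variant.
        if self.attn_config.get('attn_impl') != 'flash':
            raise ValueError('FlashOnlyMPTConfig requires attn_config["attn_impl"] == "flash".')
```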
7 changes: 6 additions & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -80,6 +80,7 @@
# isort: off
from llmfoundry.models.layers.fc import fcs # type: ignore
from llmfoundry.models.utils.param_init_fns import generic_param_init_fn_ # type: ignore
from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore
# isort: on

log = logging.getLogger(__name__)
@@ -425,6 +426,10 @@ def __init__(self, config: MPTConfig):
log.debug(self)
log.debug(f'Using {self.config.init_config["name"]} initialization.')

@property
def block_class(self) -> Type[MPTBlock]:
return MPTBlock

def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
"""Construct the nn.ModuleList with the Transformer blocks.

@@ -437,7 +442,7 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
block_args = self.extract_block_args(config.to_dict())

return nn.ModuleList([
MPTBlock(
self.block_class(
device=config.init_device,
**block_args,
) for _ in range(config.n_layers)
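Combined with the new `block_class` property, `construct_blocks` no longer hard-codes `MPTBlock`, so a model subclass can swap in a custom block class in one line. A hypothetical sketch (both subclasses are illustrative):

```python
from typing import Type

from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt.modeling_mpt import MPTModel


class MyBlock(MPTBlock):
    """Hypothetical custom block, e.g. the WindowedFFNBlock sketch above."""


class MyMPTModel(MPTModel):
    """Hypothetical model whose transformer layers are built from MyBlock."""

    @property
    def block_class(self) -> Type[MPTBlock]:
        return MyBlock
```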
22 changes: 22 additions & 0 deletions llmfoundry/registry.py
@@ -222,6 +222,27 @@
description=_icl_datasets_description,
)

_config_transforms_description = (
"""The config_transforms registry is used to register functions that transform the training config

The config will be transformed before it is used anywhere else. Note: By default ALL registered transforms will be applied to the train config
and NONE to the eval config. Each transform should return the modified config.

Args:
cfg (Dict[str, Any]): The training config.

Returns:
cfg (Dict[str, Any]): The modified training config.
"""
)
config_transforms = create_registry(
'llmfoundry',
'config_transforms',
generic_type=Callable[[Dict[str, Any]], Dict[str, Any]],
entry_points=True,
description=_config_transforms_description,
)

__all__ = [
'loggers',
'callbacks',
@@ -245,4 +266,5 @@
'attention_implementations',
'fcs',
'icl_datasets',
'config_transforms',
]
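As the registry description above says, each entry point takes the training config dict and returns the modified dict. A minimal sketch of registering a custom transform, following the same pattern this PR uses for `update_batch_size_info` (the transform name and behaviour are hypothetical):

```python
from typing import Any, Dict

from llmfoundry.registry import config_transforms


def set_default_precision(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical transform: default to bf16 mixed precision when unset."""
    cfg.setdefault('precision', 'amp_bf16')
    return cfg


config_transforms.register('set_default_precision', func=set_default_precision)
```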
6 changes: 6 additions & 0 deletions llmfoundry/utils/__init__.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.registry import config_transforms
from llmfoundry.utils.builders import (
build_algorithm,
build_callback,
@@ -59,6 +60,11 @@
experimental_function,
)

config_transforms.register(
'update_batch_size_info',
func=update_batch_size_info,
)

__all__ = [
'build_algorithm',
'build_callback',
11 changes: 8 additions & 3 deletions llmfoundry/utils/builders.py
@@ -63,7 +63,7 @@ def build_evaluators(
eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]],
*,
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
device_eval_batch_size: Union[int, float],
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
@@ -79,6 +79,10 @@
logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
if not isinstance(device_eval_batch_size, int):
raise ValueError(
'device_eval_batch_size should be an int for icl tasks.',
)
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config,
eval_gauntlet_config,
@@ -95,7 +99,7 @@ def build_eval_loaders(
def build_eval_loaders(
eval_loader_config: Union[Dict[str, Any], List[Dict[str, Any]]],
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
device_eval_batch_size: Union[int, float],
) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, list):
@@ -122,7 +126,8 @@ def build_eval_loaders(
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
device_eval_microbatch_size=device_eval_batch_size,
# TODO: Fix type in Composer
device_eval_microbatch_size=device_eval_batch_size, # type: ignore
)
evaluators.append(eval_loader)
return evaluators
82 changes: 66 additions & 16 deletions llmfoundry/utils/config_utils.py
@@ -30,6 +30,7 @@

from llmfoundry.layers_registry import ffns_with_megablocks
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.registry import config_transforms

log = logging.getLogger(__name__)

@@ -48,7 +49,7 @@ class EvalConfig:
# Eval Config required parameters:
models: List[Dict[str, Any]] = MISSING
max_seq_len: int = MISSING
device_eval_batch_size: int = MISSING
device_eval_batch_size: Union[int, float] = MISSING

# Eval Config optional parameters:
code_paths: Optional[List[str]] = None
@@ -101,7 +102,7 @@ class TrainConfig:
scheduler: Dict[str, Any] = MISSING
train_loader: Dict[str, Any] = MISSING
device_train_batch_size: Union[int, float] = MISSING
device_eval_batch_size: int = MISSING
device_eval_batch_size: Union[int, float] = MISSING
max_duration: Union[int, str] = MISSING
eval_interval: Union[int, str] = MISSING
max_seq_len: int = MISSING
@@ -160,7 +161,7 @@ class TrainConfig:
save_ignore_keys: Optional[List[str]] = None

# Dataloader
device_train_microbatch_size: Union[str, int] = 'auto'
device_train_microbatch_size: Union[str, int, float] = 'auto'
global_train_batch_size: Optional[int] = None

# Eval dataloader
@@ -238,12 +239,60 @@ def to_container(
T = TypeVar('T')


def apply_transforms_to_config(
cfg: Dict[str, Any],
transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]],
List[str], str]],
) -> Dict[str, Any]:
"""Applies a list of transforms to a config.

Args:
cfg (Dict[str, Any]): The config to transform.
transforms (Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]], List[str], str]]): A list of
transform functions or strings representing transform functions to apply to the config. If a single string
with the value ``all`` is provided, all registered transforms will be applied.

Returns:
Dict[str, Any]: The transformed config.
"""
if transforms is None or (
isinstance(transforms, list) and len(transforms) == 0
):
return cfg

transform_functions = []
if isinstance(transforms, list):
for transform in transforms:
if isinstance(transform, str):
transform_functions.append(config_transforms.get(transform))
elif callable(transform):
transform_functions.append(transform)
else:
raise ValueError(
f'Invalid transform: {transform}. Must be a string or callable.',
)
elif isinstance(transforms, str) and transforms == 'all':
transform_functions = [
config_transforms.get(transform)
for transform in config_transforms.get_all()
]
else:
raise ValueError(
f'Invalid transforms: {transforms}. Must be a list of strings or callables, or ``all``.',
)

for transform in transform_functions:
cfg = transform(cfg)

return cfg


def make_dataclass_and_log_config(
cfg: DictConfig,
dataclass_constructor: Callable[..., T],
dataclass_fields: Set[str],
transforms: Optional[List[Callable[[Dict[str, Any]], Dict[str,
Any]]]] = None,
transforms: Optional[Union[List[Callable[[Dict[str, Any]], Dict[str, Any]]],
List[str], str]] = None,
icl_tasks_required: bool = False,
) -> Tuple[Dict[str, Any], T]:
"""Converts a DictConfig to a dataclass and creates a logged config."""
@@ -281,8 +330,10 @@ def make_dataclass_and_log_config(
logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config)

# Apply transforms to the unstructured config before constructing dataclass
for transform in transforms or []:
unstructured_config = transform(unstructured_config)
unstructured_config = apply_transforms_to_config(
unstructured_config,
transforms,
)

logged_cfg.update(unstructured_config, merge=True)

Expand Down Expand Up @@ -367,20 +418,20 @@ def calculate_batch_size_info(
data_replication_degree: int = 1,
) -> Tuple[Union[int, float], Union[int, float, Literal['auto']], Union[
int, Literal['auto']]]:
if dist.get_world_size() % data_replication_degree != 0:

world_size = dist.get_world_size()
if world_size % data_replication_degree != 0:
raise ValueError(
f'World size {dist.get_world_size()} is not divisible by data replication degree {data_replication_degree}.',
f'World size {world_size} is not divisible by data replication degree {data_replication_degree}.',
)
if global_batch_size % (
dist.get_world_size() // data_replication_degree
) != 0:
if global_batch_size % (world_size // data_replication_degree) != 0:
raise ValueError(
f'Global batchsize {global_batch_size} is not divisible by {(dist.get_world_size() // data_replication_degree)=} '
f'Global batchsize {global_batch_size} is not divisible by {(world_size // data_replication_degree)=} '
+
'as a result, the batch size would be truncated, please adjust `global_batch_size` '
+ f'to be divisible by world size, {dist.get_world_size()}.',
+ f'to be divisible by world size, {world_size}.',
)
device_batch_size = global_batch_size / dist.get_world_size()
device_batch_size = global_batch_size / world_size
if device_batch_size == round(device_batch_size):
device_batch_size = round(device_batch_size)
if device_microbatch_size == 'auto':
@@ -401,7 +452,6 @@ def calculate_batch_size_info(
return device_batch_size, device_microbatch_size, device_grad_accum


# Coming soon: this conversion math will be done inside Composer Trainer
def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
data_replication_degree = 1
device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
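A short usage sketch of the new `apply_transforms_to_config` helper, assuming the built-in `update_batch_size_info` transform registered in `llmfoundry/utils/__init__.py` above (the config values are illustrative):

```python
from llmfoundry.utils.config_utils import apply_transforms_to_config

cfg = {
    'global_train_batch_size': 512,
    'device_train_microbatch_size': 'auto',
}

# Transforms can be given as registered names, as callables, or as the string
# 'all', which applies every transform in the config_transforms registry.
cfg = apply_transforms_to_config(cfg, ['update_batch_size_info'])

print(cfg['device_train_batch_size'])  # filled in from the world size
```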
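The widened `Union[int, float]` batch-size types throughout this PR come from the division in `calculate_batch_size_info`: with data replication, the per-device batch size can be fractional. A worked example with hypothetical numbers (the real helper reads the world size from torch.distributed):

```python
global_batch_size = 6
world_size = 4
data_replication_degree = 2

# Same divisibility checks as calculate_batch_size_info above.
assert world_size % data_replication_degree == 0
assert global_batch_size % (world_size // data_replication_degree) == 0

device_batch_size = global_batch_size / world_size  # 1.5
if device_batch_size == round(device_batch_size):
    device_batch_size = round(device_batch_size)  # whole values collapse back to int

print(device_batch_size)  # 1.5, hence Union[int, float] in the dataloader signatures
```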