Bump composer to 0.24.1 + FSDP config device_mesh deprecation #1487

Merged
merged 11 commits into from Aug 27, 2024
Changes from 8 commits
39 changes: 28 additions & 11 deletions llmfoundry/utils/config_utils.py
@@ -533,18 +533,35 @@ def process_init_device(model_cfg: dict[str, Any], fsdp_config: Optional[dict]):
             # Set defaults for mixed initialization
             fsdp_config.setdefault('load_monolith_rank0_only', True)

-    # Set ffn_config.device_mesh to fsdp_config.device_mesh
-    if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[
+    # Set ffn_config.device_mesh using fsdp_config
+    if fsdp_config is not None and 'ffn_config' in model_cfg and model_cfg[
         'ffn_config'].get('ffn_type', None) in ffns_with_megablocks:
-        # Raise ValueError if not using device mesh with MoE expert parallelism
-        if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get(
-            'moe_world_size',
-            1,
-        ) > 1:
-            raise ValueError(
-                'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.',
-            )
-        model_cfg['ffn_config']['device_mesh'] = fsdp_config['device_mesh']
+        shard_degree = fsdp_config.get('data_parallel_shard_degree', None)
+        replicate_degree = fsdp_config.get(
+            'data_parallel_replicate_degree',
+            None,
+        )
+
+        if shard_degree is None and replicate_degree is None:
+            # Default to sharding over all gpus.
+            shard_degree = dist.get_world_size()
+            device_mesh_cfg = [shard_degree]
+        else:
+            if shard_degree is None:
+                # Shard degree is not specified, so calculate it from replicate degree
+                assert isinstance(replicate_degree, int)
+                shard_degree = dist.get_world_size() // replicate_degree
+            elif replicate_degree is None:
+                # Replicate degree is not specified, so calculate it from shard degree
+                assert isinstance(shard_degree, int)
+                replicate_degree = dist.get_world_size() // shard_degree
+
+            if replicate_degree == 1:
+                device_mesh_cfg = [shard_degree]
+            else:
+                device_mesh_cfg = [replicate_degree, shard_degree]
+
+        model_cfg['ffn_config']['device_mesh'] = device_mesh_cfg

     # No mixed precision needed for weights when they're already 16 bits
     master_dtype = model_cfg.get('master_weights_dtype')
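For reference, the new branch logic maps the `data_parallel_shard_degree` / `data_parallel_replicate_degree` pair to a one- or two-dimensional device mesh. Below is a minimal, standalone sketch of that derivation, assuming an 8-GPU world; `derive_device_mesh` is a hypothetical helper and the explicit `world_size` argument stands in for composer's `dist.get_world_size()`:

```python
# Illustrative sketch of the device_mesh derivation above; `derive_device_mesh`
# is a hypothetical helper, and world_size stands in for dist.get_world_size().
from typing import Optional


def derive_device_mesh(
    world_size: int,
    shard_degree: Optional[int] = None,
    replicate_degree: Optional[int] = None,
) -> list[int]:
    if shard_degree is None and replicate_degree is None:
        # Neither degree given: default to sharding over every GPU.
        return [world_size]
    if shard_degree is None:
        # Derive the shard degree from the replicate degree.
        shard_degree = world_size // replicate_degree
    elif replicate_degree is None:
        # Derive the replicate degree from the shard degree.
        replicate_degree = world_size // shard_degree
    # 1D mesh when nothing is replicated, otherwise [replicate, shard].
    return [shard_degree] if replicate_degree == 1 else [replicate_degree, shard_degree]


assert derive_device_mesh(8) == [8]
assert derive_device_mesh(8, shard_degree=4) == [2, 4]
assert derive_device_mesh(8, replicate_degree=1) == [8]
```

With neither degree set, the mesh defaults to sharding across all GPUs, matching the `[8]` cases asserted in the new test added at the bottom of this diff.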
3 changes: 1 addition & 2 deletions scripts/eval/yamls/dbrx-gauntlet/dbrx-18b.yaml
@@ -55,8 +55,7 @@ models:
   precision: amp_bf16
   fsdp_config:
     verbose: false
-    device_mesh:
-    - 8
+    data_parallel_shard_degree: 8
     mixed_precision: PURE
     state_dict_type: sharded
     use_orig_params: true
3 changes: 1 addition & 2 deletions scripts/eval/yamls/dbrx-gauntlet/dbrx-35b.yaml
@@ -55,8 +55,7 @@ models:
   precision: amp_bf16
   fsdp_config:
     verbose: false
-    device_mesh:
-    - 8
+    data_parallel_shard_degree: 8
     mixed_precision: PURE
     state_dict_type: sharded
     use_orig_params: true
3 changes: 1 addition & 2 deletions scripts/eval/yamls/dbrx-gauntlet/dbrx-73b.yaml
@@ -55,8 +55,7 @@ models:
   precision: amp_bf16
   fsdp_config:
     verbose: false
-    device_mesh:
-    - 8
+    data_parallel_shard_degree: 8
     mixed_precision: PURE
     state_dict_type: sharded
     use_orig_params: true
3 changes: 1 addition & 2 deletions scripts/eval/yamls/dbrx-gauntlet/dbrx-9b.yaml
@@ -55,8 +55,7 @@ models:
   precision: amp_bf16
   fsdp_config:
     verbose: false
-    device_mesh:
-    - 8
+    data_parallel_shard_degree: 8
     mixed_precision: PURE
     state_dict_type: sharded
     use_orig_params: true
8 changes: 4 additions & 4 deletions setup.py
@@ -52,7 +52,7 @@
 ]

 install_requires = [
-    'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24',
+    'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.24.0,<0.25',
     'mlflow>=2.14.1,<2.16',
     'accelerate>=0.25,<0.34',  # for HF inference `device_map`
     'transformers>=4.43.2,<4.44',
@@ -91,14 +91,14 @@
 ]

 extra_deps['databricks'] = [
-    'mosaicml[databricks]>=0.23.4,<0.24',
+    'mosaicml[databricks]>=0.24.0,<0.25',
     'databricks-sql-connector>=3,<4',
     'databricks-connect==14.1.0',
     'lz4>=4,<5',
 ]

 extra_deps['tensorboard'] = [
-    'mosaicml[tensorboard]>=0.23.4,<0.24',
+    'mosaicml[tensorboard]>=0.24.0,<0.25',
 ]

 # Flash 2 group kept for backwards compatibility
@@ -109,7 +109,7 @@
 extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])

 extra_deps['peft'] = [
-    'mosaicml[peft]>=0.23.4,<0.24',
+    'mosaicml[peft]>=0.24.0,<0.25',
 ]

 extra_deps['openai'] = [
7 changes: 4 additions & 3 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1283,12 +1283,13 @@ def test_mptmoe_huggingface_conversion_callback(
         'activation_checkpointing_reentrant': False,
         'activation_cpu_offload': False,
         'limit_all_gathers': True,
-        'device_mesh': [1, 4] if sharding_strategy == 'HYBRID_SHARD' else [
-            4,
-        ],
         'use_orig_params': True,
+        'data_parallel_shard_degree': 4,
     }

+    if sharding_strategy == 'HYBRID_SHARD':
+        fsdp_config['data_parallel_shard_degree'] = 1
+
     tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
     tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
     if dist.get_global_rank() == 0:
57 changes: 56 additions & 1 deletion tests/utils/test_config_utils.py
@@ -1,7 +1,14 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-from llmfoundry.utils.config_utils import update_config_with_batch_size_info
+from unittest.mock import patch
+
+import pytest
+
+from llmfoundry.utils.config_utils import (
+    process_init_device,
+    update_config_with_batch_size_info,
+)


 def test_update_config_with_batch_size_info():
@@ -13,3 +20,51 @@ def test_update_config_with_batch_size_info():
     assert config['device_train_microbatch_size'] == 2
     assert config['device_train_grad_accum'] == 3
     assert config['device_eval_batch_size'] == 2
+
+
+@pytest.mark.parametrize('shard_degree_specified', [True, False])
+@pytest.mark.parametrize('replicate_degree_specified', [True, False])
+@pytest.mark.parametrize('should_shard_only', [True, False])
+def test_moe_fsdp_config_ffn_config(
+    shard_degree_specified: bool,
+    replicate_degree_specified: bool,
+    should_shard_only: bool,
+):
+    model_cfg = {
+        'moe_world_size': 4,
+        'lbl_process_group': 'not_real',
+        'fc_type': 'torch',
+        'ffn_config': {
+            'ffn_type': 'mb_moe',
+        },
+    }
+    fsdp_config = {}
+    if shard_degree_specified and replicate_degree_specified:
+        if should_shard_only:
+            fsdp_config['data_parallel_shard_degree'] = 8
+            fsdp_config['data_parallel_replicate_degree'] = 1
+        else:
+            fsdp_config['data_parallel_shard_degree'] = 4
+            fsdp_config['data_parallel_replicate_degree'] = 2
+    elif shard_degree_specified:
+        if should_shard_only:
+            fsdp_config['data_parallel_shard_degree'] = 8
+        else:
+            fsdp_config['data_parallel_shard_degree'] = 4
+    elif replicate_degree_specified:
+        if should_shard_only:
+            fsdp_config['data_parallel_replicate_degree'] = 1
+        else:
+            fsdp_config['data_parallel_replicate_degree'] = 2
+
+    # Ensure the ffn_config's device_mesh is set correctly using the fsdp_config
+    with patch('composer.utils.dist.get_world_size', return_value=8
+              ), patch('catalogue.Registry.__contains__', return_value=True):
+        _ = process_init_device(model_cfg, fsdp_config)
+
+    if should_shard_only or (
+        not shard_degree_specified and not replicate_degree_specified
+    ):
+        assert model_cfg['ffn_config']['device_mesh'] == [8]
+    else:
+        assert model_cfg['ffn_config']['device_mesh'] == [2, 4]