switch to mcore's optimizer (#448)
* switch to mcore's optimizer

Signed-off-by: dimapihtar <[email protected]>

* set fused adam optim for bert/t5/mt5

Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
dimapihtar and pablo-garay authored Dec 5, 2024
1 parent: 9a3a397 · commit: 3dd3561
Showing 100 changed files with 106 additions and 106 deletions.
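The changes below (a subset of the 100-file diff) all follow the same one-line pattern: under each model's optim block, the optimizer name switches from Apex's distributed_fused_adam to Megatron Core's (mcore's) mcore_distributed_optim. A minimal before/after sketch of the pattern, with key names taken from the configs in this commit (the surrounding optimizer options vary per file):

  # before
  optim:
    name: distributed_fused_adam
    lr: 1e-4
    weight_decay: 0.1

  # after
  optim:
    name: mcore_distributed_optim
    lr: 1e-4
    weight_decay: 0.1

The two gpt3 fine-tuning configs are the exception: they keep name: fused_adam and only update the inline comment that points to the distributed optimizer.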
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/baichuan2_13b.yaml
@@ -131,7 +131,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     grad_sync_dtype: bf16
     lr: 1e-4
     weight_decay: 0.1
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/baichuan2_7b.yaml
@@ -131,7 +131,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     grad_sync_dtype: bf16
     lr: 1e-4
     weight_decay: 0.1
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/bert.yaml
@@ -124,7 +124,7 @@ model:
   short_seq_prob: 0.1 # Probability of producing a short sequence.

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 2e-4
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/chatglm_6b.yaml
@@ -139,7 +139,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/gpt3.yaml
@@ -142,7 +142,7 @@ model:
   gen_shape: False # Generate model and kernel details including input shapes

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 6e-4
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/llama2_13b.yaml
@@ -127,7 +127,7 @@ model:
   ub_tp_comm_overlap: false
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.0001
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/llama2_70b.yaml
@@ -130,7 +130,7 @@ model:
   batch_p2p_comm: true
   gc_interval: 100
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.00015
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/llama2_7b.yaml
@@ -130,7 +130,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/llama3_70b.yaml
@@ -150,7 +150,7 @@ model:
   ranks: [0] # Global rank IDs to profile
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.00015
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/llama3_8b.yaml
@@ -148,7 +148,7 @@ model:
   ranks: [0] # Global rank IDs to profile
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/mixtral_3b.yaml
@@ -143,7 +143,7 @@ model:
   - 0
   gen_shape: false
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.0001
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/mixtral_7b.yaml
@@ -145,7 +145,7 @@ model:
   - 0
   gen_shape: false
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.0001
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/mt5.yaml
@@ -176,7 +176,7 @@ model:
   gen_shape: False # Generate model and kernel details including input shapes

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 0.0001
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/qwen2_14b.yaml
@@ -142,7 +142,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/qwen2_4b.yaml
@@ -142,7 +142,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/qwen2_72b.yaml
@@ -142,7 +142,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/qwen2_7b.yaml
@@ -142,7 +142,7 @@ model:
   ub_tp_comm_overlap: False
   use_flash_attention: true
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion auto_configurator/base_configs/t5.yaml
@@ -174,7 +174,7 @@ model:
   gen_shape: False # Generate model and kernel details including input shapes

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 0.0001
12 changes: 6 additions & 6 deletions auto_configurator/tests/base_configs_tests/test_base_configs.py
@@ -149,7 +149,7 @@ def test_gpt3_base_config(self):
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 6e-4
@@ -336,7 +336,7 @@ def test_llama_base_config(self):
   ranks: [0] # Global rank IDs to profile
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.1
     betas:
@@ -594,7 +594,7 @@ def test_mixtral_base_config(self):
   - 0
   gen_shape: false
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 0.0001
     weight_decay: 0.1
     betas:
@@ -867,7 +867,7 @@ def test_t5_base_config(self):
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 0.0001
@@ -1092,7 +1092,7 @@ def test_mt5_base_config(self):
   gen_shape: False # Generate model and kernel details including input shapes
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 0.0001
@@ -1263,7 +1263,7 @@ def test_bert_base_config(self):
   short_seq_prob: 0.1 # Probability of producing a short sequence.
   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     overlap_grad_sync: False
     bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
     lr: 2e-4
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/baichuan2/squad.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/chatglm/squad.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion (file path not captured in this view)
@@ -172,7 +172,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/falcon/squad.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/gpt3/custom_task.yaml
@@ -174,7 +174,7 @@ model:
   num_classes: null

   optim:
-    name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 5e-6
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/gpt3/squad.yaml
@@ -180,7 +180,7 @@ model:
   num_classes: null

   optim:
-    name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/llama/squad.yaml
@@ -178,7 +178,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/mamba/sft.yaml
@@ -224,7 +224,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 2e-4
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/mistral/squad.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/mixtral/squad.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/mixtral/squad_8x22b.yaml
@@ -171,7 +171,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/fine_tuning/qwen2/squad.yaml
@@ -178,7 +178,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+    name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
     lr: 1e-6
     weight_decay: 0.1
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/peft/gemma/sft.yaml
@@ -231,7 +231,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/peft/griffin/sft.yaml
@@ -283,7 +283,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-5
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/peft/llama/sft.yaml
@@ -245,7 +245,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/peft/nemotron/sft.yaml
@@ -245,7 +245,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.01
     betas:
2 changes: 1 addition & 1 deletion launcher_scripts/conf/peft/qwen2/sft.yaml
@@ -245,7 +245,7 @@ model:
   num_classes: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     lr: 1e-4
     weight_decay: 0.01
     betas:
4 changes: 2 additions & 2 deletions launcher_scripts/conf/rlhf_ppo/gpt3/2b_ppo.yaml
@@ -88,7 +88,7 @@ critic:
   num_attributes: 1

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     bucket_cap_mb: 200
     overlap_grad_sync: False
     contiguous_grad_buffer: True
@@ -261,7 +261,7 @@ actor:
   seed: 1234

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     bucket_cap_mb: 200
     overlap_grad_sync: False
     contiguous_grad_buffer: True
2 changes: 1 addition & 1 deletion launcher_scripts/conf/rlhf_rm/gpt3/2b_rm.yaml
@@ -177,7 +177,7 @@ model:
   checkpoint_name: null

   optim:
-    name: distributed_fused_adam
+    name: mcore_distributed_optim
     bucket_cap_mb: 200
     overlap_grad_sync: False
     contiguous_grad_buffer: True
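For reference, a composite sketch of an optim block using the new optimizer name, assembled from options that appear across the configs touched by this commit (not copied from any single file; which options apply depends on the model and launcher stage):

  optim:
    name: mcore_distributed_optim
    lr: 1e-4
    weight_decay: 0.1
    overlap_grad_sync: False
    grad_sync_dtype: bf16
    bucket_cap_mb: 200
    contiguous_grad_buffer: True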