diff --git a/deepspeed/pt/deepspeed_zero_optimizer.py b/deepspeed/pt/deepspeed_zero_optimizer.py index 3072e6bc75d2..b0848a8750f2 100755 --- a/deepspeed/pt/deepspeed_zero_optimizer.py +++ b/deepspeed/pt/deepspeed_zero_optimizer.py @@ -166,6 +166,7 @@ def __init__(self, if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") + logger.info(f"CPU Offload: {cpu_offload}") # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add new fused optimizer later @@ -1564,7 +1565,8 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states): dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) return dp_partitions[partition_id] else: - return all_partition_states[partition_id] + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return all_partition_states[0] # Restore base optimizer state from checkpoint by # 1) Merging optimizer state from checkpoints of all partitions diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json index 2a3b9ca5a0be..c3322eca8138 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json @@ -3,14 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":1 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 1 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json index fde222a3cca2..f6a6db57daf2 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json +++ 
b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json @@ -3,17 +3,12 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2, + "stage": 2, "reduce_bucket_size": 7000000, "allgather_bucket_size": 7000000, "reduce_scatter": true }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json new file mode 100755 index 000000000000..ad054d31bb66 --- /dev/null +++ b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json @@ -0,0 +1,21 @@ +{ + "train_batch_size": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + "reduce_bucket_size": 7000000, + "allgather_bucket_size": 7000000, + "reduce_scatter": true, + "cpu_offload": true + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + } +} diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json index 99637973cd60..63b30c225753 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json @@ -3,13 +3,7 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":0 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 0 }, "gradient_clipping": 1.0, "fp16": { diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json index 8d44659a9ee3..342fd665ccae 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json +++ 
b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json @@ -2,15 +2,10 @@ "train_batch_size": 8, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization":{ - "stage":1 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "zero_optimization": { + "stage": 1 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json index fde90e8274b8..0e2582fa102f 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json @@ -3,17 +3,12 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2, + "stage": 2, "reduce_bucket_size": 7000000, "allgather_bucket_size": 7000000, "reduce_scatter": true }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, @@ -26,5 +21,4 @@ "partition_activations": true, "contiguous_memory_optimization": true } - } diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json new file mode 100755 index 000000000000..5c66ed7cc585 --- /dev/null +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json @@ -0,0 +1,25 @@ +{ + "train_batch_size": 8, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + "reduce_bucket_size": 7000000, + "allgather_bucket_size": 7000000, + "reduce_scatter": true, + "cpu_offload": true + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "activation_checkpointing": { + "partition_activations": true, + 
"contiguous_memory_optimization": true + } +} diff --git a/tests/model/Megatron_GPT2/ds_config_func_scheduler.json b/tests/model/Megatron_GPT2/ds_config_func_scheduler.json index 60c810786bf0..2d2ab356e57c 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_scheduler.json +++ b/tests/model/Megatron_GPT2/ds_config_func_scheduler.json @@ -3,14 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 2 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "scheduler": { "type": "WarmupLR", @@ -20,7 +15,6 @@ "warmup_num_steps": 10 } }, - "fp16": { "enabled": true, "loss_scale": 0, diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs16.json b/tests/model/Megatron_GPT2/ds_config_perf_bs16.json index f160ccd8e610..a40f3e4c7d44 100644 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs16.json +++ b/tests/model/Megatron_GPT2/ds_config_perf_bs16.json @@ -2,7 +2,10 @@ "train_batch_size": 16, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization": 1, + "zero_optimization": { + "stage": 1 + }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs32.json b/tests/model/Megatron_GPT2/ds_config_perf_bs32.json index 6e23fe687bc8..096a0d3645cd 100755 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs32.json +++ b/tests/model/Megatron_GPT2/ds_config_perf_bs32.json @@ -3,8 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":1 + "stage": 1 }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs8.json b/tests/model/Megatron_GPT2/ds_config_perf_bs8.json index 514496958e14..e793e221e1e7 100644 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs8.json +++ 
b/tests/model/Megatron_GPT2/ds_config_perf_bs8.json @@ -2,7 +2,10 @@ "train_batch_size": 8, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization": 1, + "zero_optimization": { + "stage": 1 + }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_gpt2_test.sh b/tests/model/Megatron_GPT2/ds_gpt2_test.sh index 5c901f855a33..a8af44df9c7e 100755 --- a/tests/model/Megatron_GPT2/ds_gpt2_test.sh +++ b/tests/model/Megatron_GPT2/ds_gpt2_test.sh @@ -91,9 +91,9 @@ gpt_options=" \ ${ds_opt} \ ${zero_opt} \ " - +DEEPSPEED_PORT=29600 work_dir="../../../DeepSpeedExamples/Megatron-LM/" -run_cmd="(cd ${work_dir} && deepspeed --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})" +run_cmd="(cd ${work_dir} && deepspeed --master_port ${DEEPSPEED_PORT} --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})" echo ${run_cmd} eval ${run_cmd} diff --git a/tests/model/Megatron_GPT2/run_checkpoint_test.py b/tests/model/Megatron_GPT2/run_checkpoint_test.py index 116e58b98fa2..cf11af6c2ae4 100755 --- a/tests/model/Megatron_GPT2/run_checkpoint_test.py +++ b/tests/model/Megatron_GPT2/run_checkpoint_test.py @@ -97,6 +97,29 @@ def test_mp2_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp2_gpu8_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp1_gpu2_load_gpu1_node1_with_zero1(self): test_config = { "mp": 1, @@ -110,7 +133,7 
@@ def test_mp1_gpu2_load_gpu1_node1_with_zero1(self): "seq_length": 256, "heads": ATTN_HEADS, "deepspeed": True, - "tag": "ds_zero2", + "tag": "ds_zero1", "zero": True, "other_args": "", "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero1", @@ -133,7 +156,7 @@ def test_mp1_gpu2_load_gpu4_node1_with_zero1(self): "seq_length": 256, "heads": ATTN_HEADS, "deepspeed": True, - "tag": "ds_zero2", + "tag": "ds_zero1", "zero": True, "other_args": "", "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero1", @@ -166,6 +189,30 @@ def test_mp1_gpu2_load_gpu1_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp1_gpu2_load_gpu1_node1_with_zero2_offload(self): + test_config = { + "mp": 1, + "gpus": 2, + "load_gpus": 1, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp1_gpu2_load_gpu4_node1_with_zero2(self): test_config = { "mp": 1, @@ -189,6 +236,30 @@ def test_mp1_gpu2_load_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp1_gpu2_load_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 1, + "gpus": 2, + "load_gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) 
+ def test_mp2_gpu4_load_gpu2_node1_with_zero1(self): test_config = { "mp": 2, @@ -258,6 +329,30 @@ def test_mp2_gpu4_load_gpu2_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu4_load_gpu2_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 4, + "load_gpus": 2, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp2_gpu2_load_gpu4_node1_with_zero2(self): test_config = { "mp": 2, @@ -281,6 +376,30 @@ def test_mp2_gpu2_load_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu2_load_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 2, + "load_gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp2_gpu2_gpu4_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp2_gpu4_node1_without_zero(self): test_config = { "mp": 2, @@ -306,7 +425,8 @@ def test_mp2_gpu4_node1_without_zero(self): def gen_name(self, test_config, prefix): save_dir = "checkpoint_test_logs" tag = test_config["tag"] - file_name = f"_{tag}.log" + checkpoint_name = test_config["checkpoint_name"] + file_name = f"_{tag}_{checkpoint_name}.log" return os.path.join(save_dir, prefix + file_name) 
def run_test(self, test_config, r_tol): @@ -334,10 +454,15 @@ def run_test(self, test_config, r_tol): except: print("No old checkpoint") + if "cpu_optimizer" in test_config and test_config["cpu_optimizer"]: + cpu_optimizer_flag = " --cpu-optimizer" + else: + cpu_optimizer_flag = "" + #-----------------Saving Checkpoint-----------------# - #building checkpoint arguments + # building checkpoint arguments test_config[ - "other_args"] = f"\"--save {checkpoint_folder} --save-interval {checkpoint_interval}\"" + "other_args"] = f"\"--save {checkpoint_folder} --save-interval {checkpoint_interval} {cpu_optimizer_flag}\"" prefix = "gpt2_saving_checkpoint" @@ -356,10 +481,11 @@ def run_test(self, test_config, r_tol): #-----------------Loading Checkpoint-----------------# - #building checkpoint arguments - test_config["other_args"] = f"\"--load {checkpoint_folder}\"" + # building checkpoint arguments + test_config[ + "other_args"] = f"\"--load {checkpoint_folder} {cpu_optimizer_flag} \"" - #set checkpoint load iteration + # set checkpoint load iteration try: cmd = f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt" print(f"{self.id()} running cmd: {cmd}") @@ -411,20 +537,32 @@ def check_parity(self, base_file, test_file, r_tol): def checkpoint_suite(): suite = unittest.TestSuite() + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2_offload')) # Shrink DP suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2_offload')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2')) + 
suite.addTest( + GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2_offload')) # Expand DP suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2_offload')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2_offload')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_without_zero'))