
ValueError: Found optimizer configured in the DeepSpeed config, but no scheduler. Please configure a scheduler in the DeepSpeed config. #24359

Closed
luohao123 opened this issue Jun 19, 2023 · 17 comments

@luohao123

luohao123 commented Jun 19, 2023

ValueError: Found optimizer configured in the DeepSpeed config, but no scheduler. Please configure a scheduler in the DeepSpeed config.

I am using --warmup_ratio 0.03 --lr_scheduler_type "cosine" here, and I couldn't find a DeepSpeed scheduler equivalent to cosine. What should I set?

@amyeroberts
Collaborator

Hi @luohao123,

So that we can help you, could you follow the issue template and provide a minimal code snippet to reproduce the error and the running environment: run transformers-cli env in the terminal and copy-paste the output?

cc @pacman100

@jackapbutler

jackapbutler commented Jun 20, 2023

TL;DR: if you're in a rush, downgrading to a version <4.30 (e.g. 4.29.2) worked for me.

I've had the same issue 👇
I believe the previous behaviour allowed you to omit the scheduler key from your DeepSpeed configuration entirely, in which case the one specified in your TrainingArguments was used. Now it seems you have to specify matching schedulers in both the DeepSpeed config and the Hugging Face Trainer.

i.e.

DeepSpeed scheduler    Trainer scheduler        Resulting scheduler
WarmupLR               constant_with_warmup     constant_with_warmup
WarmupDecayLR          linear                   linear

whereas before you could just ignore the first column and leave it blank to get the same result

DeepSpeed scheduler    Trainer scheduler        Resulting scheduler
(none)                 constant_with_warmup     constant_with_warmup
(none)                 linear                   linear

Personally, I found it handier before, when I only had to specify the scheduler in one place rather than keeping it in sync across a DeepSpeed config and a Trainer config, which are generally separate objects.
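
For anyone hitting this on 4.30+, here is a minimal sketch (my own, based on what I recall from the HF DeepSpeed docs, so double-check against your setup) of the scheduler block you could add to the DeepSpeed config to mirror a linear Trainer schedule; the "auto" values should be filled in from the TrainingArguments by the integration:

{
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  }
}

For constant_with_warmup, the corresponding DeepSpeed type would be WarmupLR (same params minus total_num_steps).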

@pacman100
Contributor

Hello, the supported combinations now are:

  1. Trainer optimizer + Trainer scheduler - Don't specify these in the DS config and use trainer args
  2. DeepSpeed optimizer + DeepSpeed scheduler - Specify both in the DeepSpeed config; no need to specify them via Trainer args (@jackapbutler, please note this as you happen to be doing both)
  3. Trainer optimizer + DeepSpeed Scheduler - Don't specify optimizer in DS config; only set the scheduler there. Don't specify the scheduler via Trainer args.

@luohao123, the case you want is DeepSpeed optimizer + Trainer scheduler, which isn't currently supported. The suggested approach in your case would be to use Trainer optimizer + Trainer scheduler (setting 1 above).

Hope this helps.
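
For reference, a rough sketch of setting 2 (both optimizer and scheduler defined in the DeepSpeed config, so nothing optimizer- or scheduler-related needs to be passed through Trainer args); the block names and "auto" handling follow the HF DeepSpeed integration docs as I understand them, adjust as needed:

"optimizer": {
    "type": "AdamW",
    "params": {
        "lr": "auto",
        "betas": "auto",
        "eps": "auto",
        "weight_decay": "auto"
    }
},
"scheduler": {
    "type": "WarmupLR",
    "params": {
        "warmup_min_lr": "auto",
        "warmup_max_lr": "auto",
        "warmup_num_steps": "auto"
    }
}

These two blocks go at the top level of the ds_config JSON, alongside zero_optimization and the batch-size keys.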

@luohao123
Author

luohao123 commented Jun 21, 2023

@pacman100 I actually got some errors when specifying the cosine scheduler via the training args while not specifying it in the DeepSpeed config:

│ ❱  485 │   │   self.initialize_optimizer_states()                                                │
│    486 │   │   see_memory_usage("After initializing optimizer states", force=True)               │
│    487 │   │                                                                                     │
│    488 │   │   if dist.get_rank() == 0:                                                          │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:620 in      │
│ initialize_optimizer_states                                                                      │
│                                                                                                  │
│    617 │   │   if isinstance(self.optimizer, torch.optim.Adagrad):                               │
│    618 │   │   │   self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, *  │
│    619 │   │   else:                                                                             │
│ ❱  620 │   │   │   self.optimizer.step()                                                         │
│    621 │   │                                                                                     │
│    622 │   │   if not self.cpu_offload:                                                          │
│    623 │   │   │   for group in self.single_partition_of_fp32_groups:                            │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:69 in wrapper           │
│                                                                                                  │
│     66 │   │   │   │   instance = instance_ref()                                                 │
│     67 │   │   │   │   instance._step_count += 1                                                 │
│     68 │   │   │   │   wrapped = func.__get__(instance, cls)                                     │
│ ❱   69 │   │   │   │   return wrapped(*args, **kwargs)                                           │
│     70 │   │   │                                                                                 │
│     71 │   │   │   # Note that the returned function here is no longer a bound method,           │
│     72 │   │   │   # so attributes like `__func__` and `__self__` no longer exist.               │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py:280 in wrapper             │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/optimizer.py:33 in _use_grad            │
│                                                                                                  │
│    30 │   │   prev_grad = torch.is_grad_enabled()                                                │
│    31 │   │   try:                                                                               │
│    32 │   │   │   torch.set_grad_enabled(self.defaults['differentiable'])                        │
│ ❱  33 │   │   │   ret = func(self, *args, **kwargs)                                              │
│    34 │   │   finally:                                                                           │
│    35 │   │   │   torch.set_grad_enabled(prev_grad)                                              │
│    36 │   │   return ret                                                                         │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/adamw.py:171 in step                    │
│                                                                                                  │
│   168 │   │   │   │   state_steps,                                                               │
│   169 │   │   │   )                                                                              │
│   170 │   │   │                                                                                  │
│ ❱ 171 │   │   │   adamw(                                                                         │
│   172 │   │   │   │   params_with_grad,                                                          │
│   173 │   │   │   │   grads,                                                                     │
│   174 │   │   │   │   exp_avgs,                                                                  │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/adamw.py:321 in adamw                   │
│                                                                                                  │
│   318 │   else:                                                                                  │
│   319 │   │   func = _single_tensor_adamw                                                        │
│   320 │                                                                                          │
│ ❱ 321 │   func(                                                                                  │
│   322 │   │   params,                                                                            │
│   323 │   │   grads,                                                                             │
│   324 │   │   exp_avgs,                                                                          │
│                                                                                                  │
│ /root/anaconda3/lib/python3.10/site-packages/torch/optim/adamw.py:564 in _multi_tensor_adamw     │
│                                                                                                  │
│   561 │   │   │   │   torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)            │
│   562 │   │   │   │   denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)                       │
│   563 │   │   │   else:                                                                          │
│ ❱ 564 │   │   │   │   exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)                  │
│   565 │   │   │   │   torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)                │
│   566 │   │   │   │   denom = torch._foreach_add(exp_avg_sq_sqrt, eps)                           │
│   567                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This doesn't look right. I'm running on A100 GPUs, can you take a look?

This is my ds config:

{
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "opt_level": "O2",
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "loss_scale": 0
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}

These are my training args:

CUDA_VISIBLE_DEVICES=2,3 deepspeed --master_port 61000 train_full.py \
    --data_path ./data/train_data.json \
    --model_name_or_path ./checkpoints/baichuan-7B/ \
    --per_device_train_batch_size 4 --output_dir out/bc_full \
    --bf16 --num_train_epochs 3 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --learning_rate 2e-5 --weight_decay 0. \
    --warmup_ratio 0.03 --lr_scheduler_type "cosine" \
    --model_max_length 1024 \
    --logging_steps 50 \
    --lazy_preprocess True \
    --deepspeed configs/ds_s2_fschat.json

What did I do wrong?

@pacman100
Contributor

Hello @luohao123, please provide a minimal reproducible example so we can dig deeper. Things work fine for me with the official example:

ds config:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

Command:

cd transformers
export TASK_NAME=mrpc
CUDA_VISIBLE_DEVICES=2,3  deepspeed ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased --task_name $TASK_NAME --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 16 --learning_rate 5e-5 --num_train_epochs 3 --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --deepspeed ds_config_zero2.json --lr_scheduler_type "cosine"

output logs:

[2023-06-22 09:47:48,765] [INFO] [config.py:964:print]   zero_enabled ................. True
[2023-06-22 09:47:48,765] [INFO] [config.py:964:print]   zero_force_ds_cpu_optimizer .. True
[2023-06-22 09:47:48,765] [INFO] [config.py:964:print]   zero_optimization_stage ...... 2
[2023-06-22 09:47:48,765] [INFO] [config.py:950:print_user_config]   json = {
    "fp16": {
        "enabled": false, 
        "loss_scale": 0, 
        "loss_scale_window": 1000, 
        "initial_scale_power": 16, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "bf16": {
        "enabled": false
    }, 
    "zero_optimization": {
        "stage": 2, 
        "offload_optimizer": {
            "device": "cpu", 
            "pin_memory": true
        }, 
        "allgather_partitions": true, 
        "allgather_bucket_size": 2.000000e+08, 
        "overlap_comm": true, 
        "reduce_scatter": true, 
        "reduce_bucket_size": 2.000000e+08, 
        "contiguous_gradients": true
    }, 
    "gradient_accumulation_steps": 1, 
    "gradient_clipping": 1.0, 
    "steps_per_print": inf, 
    "train_batch_size": 32, 
    "train_micro_batch_size_per_gpu": 16, 
    "wall_clock_breakdown": false, 
    "zero_allow_untested_optimizer": true
}
Using /raid/sourab/.cache/huggingface/torch_extensions/py311_cu118 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.00022840499877929688 seconds
[INFO|trainer.py:1680] 2023-06-22 09:47:48,766 >> ***** Running training *****
[INFO|trainer.py:1681] 2023-06-22 09:47:48,766 >>   Num examples = 3,668
[INFO|trainer.py:1682] 2023-06-22 09:47:48,766 >>   Num Epochs = 3
[INFO|trainer.py:1683] 2023-06-22 09:47:48,766 >>   Instantaneous batch size per device = 16
[INFO|trainer.py:1684] 2023-06-22 09:47:48,766 >>   Total train batch size (w. parallel, distributed & accumulation) = 32
[INFO|trainer.py:1685] 2023-06-22 09:47:48,766 >>   Gradient Accumulation steps = 1
[INFO|trainer.py:1686] 2023-06-22 09:47:48,766 >>   Total optimization steps = 345
[INFO|trainer.py:1687] 2023-06-22 09:47:48,766 >>   Number of trainable parameters = 108,311,810
[INFO|integrations.py:727] 2023-06-22 09:47:48,767 >> Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: Currently logged in as: smangrul. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.4
wandb: Run data is saved locally in /home/sourab/transformers/wandb/run-20230622_094749-h2mion2e
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run rose-vortex-320
wandb: ⭐️ View project at https://wandb.ai/smangrul/huggingface
wandb: 🚀 View run at https://wandb.ai/smangrul/huggingface/runs/h2mion2e
  0%|                                                                                                  | 0/345 [00:00<?, ?it/s]/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1829: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1687280020902/work/torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
/home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/deepspeed/runtime/zero/stage_1_and_2.py:1829: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at /opt/conda/conda-bld/pytorch_1687280020902/work/torch/csrc/tensor/python_tensor.cpp:83.)
  overflow_gpu = get_accelerator().ByteTensor([overflow])
100%|████████████████████████████████████████████████████████████████████████████████████████| 345/345 [00:57<00:00,  6.13it/s][INFO|trainer.py:1924] 2023-06-22 09:48:49,820 >> 

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 61.0539, 'train_samples_per_second': 180.234, 'train_steps_per_second': 5.651, 'train_loss': 0.4465487715126812, 'epoch': 3.0}
100%|████████████████████████████████████████████████████████████████████████████████████████| 345/345 [00:57<00:00,  6.03it/s]
[INFO|trainer.py:2832] 2023-06-22 09:48:49,823 >> Saving model checkpoint to /tmp/mrpc/
[INFO|configuration_utils.py:458] 2023-06-22 09:48:49,824 >> Configuration saved in /tmp/mrpc/config.json
[INFO|modeling_utils.py:1845] 2023-06-22 09:48:50,616 >> Model weights saved in /tmp/mrpc/pytorch_model.bin
[INFO|tokenization_utils_base.py:2215] 2023-06-22 09:48:50,617 >> tokenizer config file saved in /tmp/mrpc/tokenizer_config.json
[INFO|tokenization_utils_base.py:2222] 2023-06-22 09:48:50,617 >> Special tokens file saved in /tmp/mrpc/special_tokens_map.json
***** train metrics *****
  epoch                    =        3.0
  train_loss               =     0.4465
  train_runtime            = 0:01:01.05
  train_samples            =       3668
  train_samples_per_second =    180.234
  train_steps_per_second   =      5.651
06/22/2023 09:48:50 - INFO - __main__ - *** Evaluate ***
[INFO|trainer.py:769] 2023-06-22 09:48:50,645 >> The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence1, sentence2, idx. If sentence1, sentence2, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
[INFO|trainer.py:3106] 2023-06-22 09:48:50,646 >> ***** Running Evaluation *****
[INFO|trainer.py:3108] 2023-06-22 09:48:50,646 >>   Num examples = 408
[INFO|trainer.py:3111] 2023-06-22 09:48:50,646 >>   Batch size = 8
100%|██████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 52.94it/s]
***** eval metrics *****
  epoch                   =        3.0
  eval_accuracy           =     0.8431
  eval_combined_score     =     0.8664
  eval_f1                 =     0.8897
  eval_loss               =     0.3868
  eval_runtime            = 0:00:00.51
  eval_samples            =        408
  eval_samples_per_second =     797.59
  eval_steps_per_second   =     50.827
wandb: Waiting for W&B process to finish... (success).
[2023-06-22 09:48:52,926] [INFO] [launch.py:347:main] Process 3002010 exits successfully.
wandb: 
wandb: Run history:
wandb:                  eval/accuracy ▁
wandb:            eval/combined_score ▁
wandb:                        eval/f1 ▁
wandb:                      eval/loss ▁
wandb:                   eval/runtime ▁
wandb:        eval/samples_per_second ▁
wandb:          eval/steps_per_second ▁
wandb:                    train/epoch ▁▁
wandb:              train/global_step ▁▁
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb: 
wandb: Run summary:
wandb:                  eval/accuracy 0.84314
wandb:            eval/combined_score 0.8664
wandb:                        eval/f1 0.88966
wandb:                      eval/loss 0.38684
wandb:                   eval/runtime 0.5115
wandb:        eval/samples_per_second 797.59
wandb:          eval/steps_per_second 50.827
wandb:                    train/epoch 3.0
wandb:              train/global_step 345
wandb:               train/total_flos 726186493739008.0
wandb:               train/train_loss 0.44655
wandb:            train/train_runtime 61.0539
wandb: train/train_samples_per_second 180.234
wandb:   train/train_steps_per_second 5.651
wandb: 
wandb: 🚀 View run rose-vortex-320 at: https://wandb.ai/smangrul/huggingface/runs/h2mion2e
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230622_094749-h2mion2e/logs
[2023-06-22 09:49:01,927] [INFO] [launch.py:347:main] Process 3002009 exits successfully.

@luohao123
Author

@pacman100 Thank you, let me try your config and test again. I notice your config isn't exactly the same as mine.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Dominic789654

Hello, the supported combinations now are:

  1. Trainer optimizer + Trainer scheduler - Don't specify these in the DS config and use trainer args
  2. DeepSpeed optimizer + DeepSpeed scheduler - Specify both in the DeepSpeed config; no need to specify them via Trainer args (@jackapbutler, please note this as you happen to be doing both)
  3. Trainer optimizer + DeepSpeed Scheduler - Don't specify optimizer in DS config; only set the scheduler there. Don't specify the scheduler via Trainer args.

@luohao123, the case you want is DeepSpeed optimizer + Trainer scheduler, which isn't currently supported. The suggested approach in your case would be to use Trainer optimizer + Trainer scheduler (setting 1 above).

Hope this helps.

Hi, I want to know: if I use setting 1, will the optimizer utilize DeepSpeed's CPUAdam?

@pacman100
Contributor

Hi, I want to know: if I use setting 1, will the optimizer utilize DeepSpeed's CPUAdam?

Yes, by default zero_force_ds_cpu_optimizer is set to True if not explicitly specified in the ds_config. As such, it will leverage DeepSpeed's CPUAdam when offloading, as strongly recommended by the DeepSpeed team.
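
As a side note, if someone explicitly wants to opt out of that behaviour, the flag can be set at the top level of the ds_config. This is an untested sketch; with CPU offload DeepSpeed strongly recommends sticking with its CPUAdam, so only flip this if you really need the torch optimizer:

{
  "zero_force_ds_cpu_optimizer": false
}

It merges with the rest of your existing config keys (zero_optimization, fp16, etc.).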

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@michaelroyzen

I'm trying to use DeepSpeed optimizer + Trainer scheduler because DeepSpeed has the best optimizer for my use case (fused Adam) and Trainer has the best scheduler for my use case (cosine); DeepSpeed does not support a cosine schedule. Why was DeepSpeed optimizer + Trainer scheduler deprecated without any warning? I think this is a mistake and that you should reconsider, @pacman100.

@pacman100
Contributor

Hello @michaelroyzen, the PRs #25863 and huggingface/accelerate#1909 should bring back support for DeepSpeed optimizer + Trainer scheduler. Could you try it out and let us know?
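
For anyone else wanting this combination once those PRs are in a release, a rough sketch (my own, not taken from the PRs) would be to keep only the optimizer block in the DeepSpeed config, e.g.:

"optimizer": {
    "type": "AdamW",
    "params": {
        "lr": "auto",
        "betas": "auto",
        "eps": "auto",
        "weight_decay": "auto"
    }
}

and drive the schedule from the Trainer side, e.g. --lr_scheduler_type "cosine" --warmup_ratio 0.03. As far as I know, DeepSpeed backs this AdamW type with its fused Adam implementation when you are not offloading to CPU.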

@michaelroyzen

Seems to work well so far @pacman100. Thanks!

@awasthiabhijeet

Hi @pacman100 ,
Is PR #25863 part of the latest transformers version?

I still observe the following error, despite using the ds_config_z3_ds_optim_hf_scheduler.json.
ValueError: Found optimizer configured in the DeepSpeed config, but no scheduler. Please configure a scheduler in the DeepSpeed config.

@pacman100
Contributor

Hello @awasthiabhijeet, it should be part of the latest release, could you recheck it?

@awasthiabhijeet

Thanks, @pacman100 :)
Yes, the HF scheduler + DS optimizer combination is working well with the latest release!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Nov 2, 2023