[BUG] NCCL times out when training with ZeRO-3. How can I solve this problem? #5066

Open
awzhgw opened this issue Feb 3, 2024 · 8 comments
Labels: bug, training

awzhgw commented Feb 3, 2024

NCCL times out when training with ZeRO-3. How can I solve this problem?

I extended the Mixtral 8x7B model on top of the LLaVA architecture, augmenting it with multi-modal capabilities for video and audio.

The architecture of my model is as follows:

LlavaMixtralForCausalLM(
  (model): LlavaMixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0): MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBLockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
    (image_tower): CLIPVisionTower(
      (image_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
          (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
          )
          (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (video_tower): LanguageBindVideoTower(
      (video_tower): CLIPVisionTransformer(
        (embeddings): CLIPVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
          (position_embedding): Embedding(257, 1024)
        )
        (patch_dropout): PatchDropout()
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): CLIPEncoder(
          (layers): ModuleList(
            (0-23): 24 x CLIPEncoderLayer(
              (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                (activation_fn): GELUActivation()
                (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                (fc2): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (temporal_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (temporal_layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
          )
        )
        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (mm_projector): build_projector(
      (image_spatial_proj): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=4096, bias=True)
      )
      (video_patch_proj): Linear(in_features=1024, out_features=4096, bias=True)
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

After initializing the model, I call deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) and then verify it with:

print('model z3_leaf_model is ', deepspeed.utils.get_z3_leaf_modules(model))

The printed output is:

model z3_leaf_model is  [MixtralSparseMoeBlock(
  (gate): Linear(in_features=4096, out_features=8, bias=False)
  (experts): ModuleList(
    (0-7): 8 x MixtralBLockSparseTop2MLP(
      (w1): Linear(in_features=4096, out_features=14336, bias=False)
      (w2): Linear(in_features=14336, out_features=4096, bias=False)
      (w3): Linear(in_features=4096, out_features=14336, bias=False)
      (act_fn): SiLU()
    )
  )
)]

This confirms that MixtralSparseMoeBlock has been registered correctly as a ZeRO-3 leaf module.
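
For reference, here is a minimal sketch of the registration step as I use it (model construction is omitted; the import path is the standard transformers one for Mixtral, so adapt it to your setup):

import deepspeed
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# `model` is the LlavaMixtralForCausalLM instance shown above (construction omitted).
# Marking MixtralSparseMoeBlock as a ZeRO-3 leaf module tells DeepSpeed to stop its
# per-submodule hooks at this block and gather all of its parameters as one unit,
# instead of fetching each expert separately based on data-dependent routing.
deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

# Sanity check: this should list the MixtralSparseMoeBlock instances that were marked.
print('model z3_leaf_model is ', deepspeed.utils.get_z3_leaf_modules(model))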


My DeepSpeed version is the current master branch.

The training behavior is as follows:

Scenario 1: When training with DeepSpeed ZeRO-3, if the training data contains only images, there are no issues and training proceeds normally.

Scenario 2: When training with DeepSpeed ZeRO-3, if the training data contains both images and videos, training gets stuck after 270 steps and eventually hits an NCCL timeout.

The error message is as follows.

{'loss': 6.8843, 'learning_rate': 0.0009838432886246189, 'epoch': 0.03}
  3%|| 270/9847 [06:18<3:39:12,  1.37s/it]Invalidate trace cache @ step 1: expected module 15, but got module 313
[E ProcessGroupNCCL.cpp:467] [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:467] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] NCCL watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=75264, NumelOut=602112, Timeout(ms)=1800000) ran for 1800709 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] NCCL watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800392 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 7] NCCL watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 7] NCCL watchdog thread terminated with exception: [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800860 milliseconds before timing out.
[2024-02-03 11:48:17,107] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694488
[2024-02-03 11:48:17,108] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694489
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:852] [Rank 6] NCCL watchdog thread terminated with exception: [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800436 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 6] NCCL watchdog thread terminated with exception: [Rank 6] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800376 milliseconds before timing out.
[2024-02-03 11:48:20,395] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694490
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:852] [Rank 3] NCCL watchdog thread terminated with exception: [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 3] NCCL watchdog thread terminated with exception: [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800869 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:852] [Rank 4] NCCL watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 4] NCCL watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800988 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:481] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:487] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:852] [Rank 5] NCCL watchdog thread terminated with exception: [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 5] NCCL watchdog thread terminated with exception: [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=202558, OpType=_ALLGATHER_BASE, NumelIn=2359296, NumelOut=18874368, Timeout(ms)=1800000) ran for 1800882 milliseconds before timing out.
[2024-02-03 11:48:25,872] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694491
[2024-02-03 11:48:27,258] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694492
[2024-02-03 11:48:27,261] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694493
[2024-02-03 11:48:27,263] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694494
[2024-02-03 11:48:27,265] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 2694495
[2024-02-03 11:48:27,267] [ERROR] [launch.py:322:sigkill_handler] ['/usr/bin/python', '-u', 'moellava/train/train_mem.py', '--local_rank=7', '--deepspeed', './scripts/zero3_offload.json', '--model_name_or_path', '/export/App/training_platform/PinoModel/mixtral/Mixtral-8x7B-Instruct-v0.1', '--version', 'mixtral', '--data_path', '/mnt/moe/moe/dataset/data_root/train_json/pretrain/valley_llavaimage.json', '--image_folder', '/mnt/moe/moe/dataset/data_root', '--image_tower', '/export/App/training_platform/PinoModel/openai/clip-vit-large-patch14-336', '--image_projector_type', 'mlp2x_gelu', '--video_tower', '/export/App/training_platform/PinoModel/LanguageBind/LanguageBind_Video_merge', '--video_folder', '/mnt/moe/moe/dataset/data_root', '--tune_mm_mlp_adapter', 'True', '--mm_vision_select_layer', '-2', '--mm_use_im_start_end', 'False', '--mm_use_im_patch_token', 'False', '--bf16', 'True', '--output_dir', './checkpoints/llavamixtral-7b-pretrain', '--num_train_epochs', '1', '--per_device_train_batch_size', '16', '--per_device_eval_batch_size', '4', '--gradient_accumulation_steps', '1', '--evaluation_strategy', 'no', '--save_strategy', 'steps', '--save_steps', '2400', '--save_total_limit', '1', '--learning_rate', '1e-3', '--weight_decay', '0.', '--warmup_ratio', '0.03', '--lr_scheduler_type', 'cosine', '--logging_steps', '1', '--tf32', 'True', '--model_max_length', '2048', '--gradient_checkpointing', 'True', '--dataloader_num_workers', '8', '--lazy_preprocess', 'True', '--report_to', 'tensorboard', '--cache_dir', './cache_dir'] exits with return code = -6

While NCCL was stuck, I used py-spy to capture where the Python process was hanging:

root@A03-R40-I16-12-8000045:/export/App/training_platform/PinoModel# py-spy dump -p 3261644
Process 3261644: /usr/bin/python -u moellava/train/train_mem.py --local_rank=5 --deepspeed ./scripts/zero3_offload.json --model_name_or_path /export/App/training_platform/PinoModel/mixtral/Mixtral-8x7B-Instruct-v0.1 --version mixtral --data_path /mnt/moe/moe/dataset/data_root/train_json/pretrain/valley_llavaimage.json --image_folder /mnt/moe/moe/dataset/data_root --image_tower /export/App/training_platform/PinoModel/openai/clip-vit-large-patch14-336 --image_projector_type mlp2x_gelu --video_tower /export/App/training_platform/PinoModel/LanguageBind/LanguageBind_Video_merge --video_folder /mnt/moe/moe/dataset/data_root --tune_mm_mlp_adapter True --mm_vision_select_layer -2 --mm_use_im_start_end False --mm_use_im_patch_token False --bf16 True --output_dir ./checkpoints/llavamixtral-7b-pretrain --num_train_epochs 1 --per_device_train_batch_size 16 --per_device_eval_batch_size 4 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 2400 --save_total_limit 1 --learning_rate 1e-3 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type cosine --logging_steps 1 --tf32 True --model_max_length 2048 --gradient_checkpointing True --dataloader_num_workers 8 --lazy_preprocess True --report_to tensorboard --cache_dir ./cache_dir
Python v3.10.12 (/usr/bin/python3.10)

Thread 3261644 (active): "MainThread"
    <listcomp> (deepspeed/runtime/zero/partition_parameters.py:1138)
    _all_gather_dtype (deepspeed/runtime/zero/partition_parameters.py:1138)
    all_gather_coalesced (deepspeed/runtime/zero/partition_parameters.py:1252)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    __all_gather_params_ (deepspeed/runtime/zero/partitioned_param_coordinator.py:458)
    __all_gather_params (deepspeed/runtime/zero/partitioned_param_coordinator.py:429)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    fetch_sub_module (deepspeed/runtime/zero/partitioned_param_coordinator.py:380)
    decorate_context (torch/utils/_contextlib.py:115)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    pre_sub_module_forward_function (deepspeed/runtime/zero/parameter_offload.py:452)
    decorate_context (torch/utils/_contextlib.py:115)
    _pre_forward_module_hook (deepspeed/runtime/zero/parameter_offload.py:340)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _call_impl (torch/nn/modules/module.py:1557)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:263)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:372)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (torch/utils/checkpoint.py:230)
    apply (torch/autograd/function.py:539)
    checkpoint (torch/utils/checkpoint.py:450)
    inner (torch/_dynamo/external_utils.py:17)
    _fn (torch/_dynamo/eval_frame.py:333)
    inner (torch/_compile.py:24)
    forward (transformers/models/clip/modeling_clip.py:622)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:844)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (transformers/models/clip/modeling_clip.py:917)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (clip_encoder.py:50)
    decorate_context (torch/utils/_contextlib.py:115)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    encode_images (moellava/model/llava_arch.py:152)
    prepare_inputs_labels_for_multimodal (moellava/model/llava_arch.py:198)
    forward (llava_mixtral.py:83)
    _call_impl (torch/nn/modules/module.py:1568)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    forward (deepspeed/runtime/engine.py:1842)
    wrapped_fn (deepspeed/utils/nvtx.py:15)
    _call_impl (torch/nn/modules/module.py:1527)
    _wrapped_call_impl (torch/nn/modules/module.py:1518)
    compute_loss (transformers/trainer.py:2795)
    training_step (transformers/trainer.py:2772)
    _inner_training_loop (transformers/trainer.py:1868)
    train (transformers/trainer.py:1539)
    train (train.py:1475)
    <module> (train_mem.py:13)
Thread 3262753 (idle): "Thread-1"
    select (selectors.py:416)
    wait (multiprocessing/connection.py:931)
    wait_result_broken_or_wakeup (concurrent/futures/process.py:385)
    run (concurrent/futures/process.py:320)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3264158 (idle): "Thread-2"
    wait (threading.py:324)
    wait (threading.py:607)
    run (tqdm/_monitor.py:60)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3267395 (idle): "Thread-3 (_pin_memory_loop)"
    select (selectors.py:416)
    wait (multiprocessing/connection.py:931)
    _poll (multiprocessing/connection.py:424)
    poll (multiprocessing/connection.py:257)
    get (multiprocessing/queues.py:113)
    do_one_step (torch/utils/data/_utils/pin_memory.py:31)
    _pin_memory_loop (torch/utils/data/_utils/pin_memory.py:54)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268088 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268152 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268153 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268154 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268155 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268156 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268157 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3268158 (idle): "QueueFeederThread"
    wait (threading.py:320)
    _feed (multiprocessing/queues.py:231)
    run (threading.py:953)
    _bootstrap_inner (threading.py:1016)
    _bootstrap (threading.py:973)
Thread 3303923 (idle)
Thread 3303931 (idle)
Thread 3303916 (idle)
Thread 3303934 (idle)
Thread 3303942 (idle)
Thread 3303945 (idle)
Thread 3303952 (idle)
Thread 3303949 (idle)

awzhgw commented Feb 3, 2024

@tohtana can you help me?

@hanxiaotian

same here

@jingcangcang

same here

@hanxiaotian

Found the potential cause: during training, some experts don't see any token, so they get no gradients and all the other processes get stuck waiting for them. After feeding a fake gradient to the experts that see no tokens, training runs smoothly.

@WanqiZhong

Found the potential cause: during training, some experts don't see any token, so they get no gradients and all the other processes get stuck waiting for them. After feeding a fake gradient to the experts that see no tokens, training runs smoothly.

Could you please provide an example of how to feed a fake gradient to the experts? Much appreciated! @hanxiaotian

@hanxiaotian

Found the potential cause: during training, some experts don't see any token, so they get no gradients and all the other processes get stuck waiting for them. After feeding a fake gradient to the experts that see no tokens, training runs smoothly.

Could you please provide an example of how to feed a fake gradient to the experts? Much appreciated! @hanxiaotian

Something like the modification below to the expert loop in the HF Mixtral implementation:

    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])

        if top_x.shape[0] == 0:
            if self.training:
                # No token was routed to this expert on this rank. Run a dummy
                # forward pass on a zeroed input so the expert's parameters still
                # participate in the ZeRO-3 all-gather and receive (zero) gradients,
                # keeping the collectives on all ranks in sync.
                top_x_ = torch.zeros(1, dtype=torch.int32, device=hidden_states.device)
                top_x_list = top_x_.tolist()
                current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
                fake_state = expert_layer(current_state * 0)
                # Adding an all-zero contribution leaves the real output unchanged.
                final_hidden_states.index_add_(
                    0, top_x_, fake_state.to(hidden_states.dtype)
                )
            else:
                # At inference there is nothing to keep in sync, so just skip the expert.
                continue
        else:
            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = (
                expert_layer(current_state)
                * routing_weights[top_x_list, idx_list, None]
            )

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

Hope this can help.

@QAQdev

QAQdev commented Nov 19, 2024

Found the potential cause: during training, some experts don't see any token, so they get no gradients and all the other processes get stuck waiting for them. After feeding a fake gradient to the experts that see no tokens, training runs smoothly.

This comment is an absolute lifesaver! God bless you!

@tohtana
Contributor

tohtana commented Nov 25, 2024

Hi @awzhgw @QAQdev, we are considering how we could fix this issue with the leaf module feature. Can you share a repro if possible?
