-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Training Loss inconsistent after resume from old checkpoint #25340
Comments
Hi @dumpmemory, thanks for raising this issue! So that we can best try and help could you:
|
flash attention patch
from typing import List, Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward pass routed through the flash-attn packed-QKV kernel.

    Input shape: Batch x Time x Channel.

    Args:
        hidden_states: Input of shape ``[bsz, q_len, channels]``.
        attention_mask: Optional ``[bsz, q_len]`` mask, used directly as the
            key padding mask handed to ``unpad_input``.
        position_ids: Forwarded unchanged to ``apply_rotary_pos_emb``.
        past_key_value: Must be ``None``.
        output_attentions: Must be falsy.
        use_cache: Must be falsy.

    Returns:
        ``(attn_output, None, None)`` where ``attn_output`` is the result of
        ``self.o_proj`` applied to the re-merged heads.

    Raises:
        AssertionError: If ``past_key_value``, ``output_attentions`` or
            ``use_cache`` is supplied.
    """
    from einops import rearrange
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
    from flash_attn.bert_padding import unpad_input, pad_input

    batch, seq_len, _ = hidden_states.size()
    per_head_shape = (batch, seq_len, self.num_heads, self.head_dim)

    # Project, split the channel axis into heads ([bsz, q_len, nh, hd]) and
    # move the head axis forward ([bsz, nh, q_len, hd]).
    query_states = self.q_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(*per_head_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(*per_head_shape).transpose(1, 2)

    rotary_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=rotary_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # query/key/value are all [bsz, nh, t, hd] at this point.

    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention code adapted from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
    # Pack q, k, v into the single tensor layout the kernel expects.
    packed = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    packed = packed.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]

    # _prepare_decoder_attention_mask is disabled in LlamaModel, so the
    # incoming attention_mask is already the key padding mask.
    padding_mask = attention_mask

    if padding_mask is not None:
        # Padded batch: drop the padded positions, run the kernel on the
        # remaining tokens, then scatter the result back to [bsz, q_len, ...].
        head_count = packed.shape[-2]
        flat = rearrange(packed, "b s three h d -> b s (three h d)")
        flat_unpadded, token_indices, cu_q_lens, longest = unpad_input(
            flat, padding_mask
        )
        qkv_unpadded = rearrange(
            flat_unpadded,
            "nnz (three h d) -> nnz three h d",
            three=3,
            h=head_count,
        )
        attended_unpadded = flash_attn_unpadded_qkvpacked_func(
            qkv_unpadded, cu_q_lens, longest, 0.0, softmax_scale=None, causal=True
        )
        repadded = pad_input(
            rearrange(attended_unpadded, "nnz h d -> nnz (h d)"),
            token_indices,
            batch,
            seq_len,
        )
        output = rearrange(repadded, "b s (h d) -> b s h d", h=head_count)
    else:
        # No mask: every sequence has exactly q_len tokens, so the cumulative
        # sequence boundaries are the multiples of q_len.
        packed = rearrange(packed, "b s ... -> (b s) ...")
        longest = seq_len
        cu_q_lens = torch.arange(
            0,
            (batch + 1) * seq_len,
            step=seq_len,
            dtype=torch.int32,
            device=packed.device,
        )
        attended = flash_attn_unpadded_qkvpacked_func(
            packed, cu_q_lens, longest, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(attended, "(b s) ... -> b s ...", b=batch)

    merged_heads = rearrange(output, "b s h d -> b s (h d)")
    return self.o_proj(merged_heads), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def forward_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward pass using ``F.scaled_dot_product_attention`` in training.

    In training mode the fused kernel is called with ``is_causal=True`` and
    ``attention_mask`` is NOT consulted. In eval mode attention is computed
    explicitly (matmul, optional additive mask, fp32 softmax, matmul).

    Args:
        hidden_states: Input of shape ``[bsz, q_len, channels]``.
        attention_mask: Optional additive mask of shape
            ``[bsz, 1, q_len, kv_seq_len]``; only used in eval mode.
        position_ids: Forwarded unchanged to ``apply_rotary_pos_emb``.
        past_key_value: Must be ``None`` (KV caching is not supported).
        output_attentions: Must be falsy.
        use_cache: Must be falsy.

    Returns:
        ``(attn_output, None, None)``. Attention weights and the KV cache are
        never returned because the modes that would produce them are rejected.

    Raises:
        AssertionError: If ``output_attentions``, ``use_cache`` or
            ``past_key_value`` is supplied.
        ValueError: If the attention weights, the mask, or the attention
            output have an unexpected shape.
    """
    # Function-scope import: the eval path needs math.sqrt, and the
    # surrounding module does not import math.
    import math

    # Reject unsupported modes before doing any work. The original checked
    # these only after touching past_key_value, which left its cache-handling
    # statements unreachable; they have been removed.
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"
    assert past_key_value is None, "past_key_value is not supported"

    bsz, q_len, _ = hidden_states.size()

    # Project and reshape to [bsz, num_heads, q_len, head_dim].
    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    # Without a cache the key/value length is just the current sequence length.
    kv_seq_len = key_states.shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    if self.training:
        # Fused causal attention; attention_mask is intentionally not passed.
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, dropout_p=0.0, is_causal=True
        )
    else:
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        expected_weights_shape = (bsz, self.num_heads, q_len, kv_seq_len)
        if attn_weights.size() != expected_weights_shape:
            # Report the same 4-D shape that is actually being checked.
            raise ValueError(
                f"Attention weights should be of size {expected_weights_shape}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Clamp to the dtype's minimum so masked positions cannot
            # overflow to -inf. Build the floor on the weights' own device
            # and dtype instead of a default-device tensor.
            floor = torch.tensor(
                torch.finfo(attn_weights.dtype).min,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
            attn_weights = torch.max(attn_weights, floor)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # Merge heads back: [bsz, nh, q_len, hd] -> [bsz, q_len, hidden_size].
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # output_attentions and use_cache are rejected above, so neither the
    # attention weights nor a KV cache is ever returned.
    return attn_output, None, None
def replace_llama_attn_with_flash_attn():
if hasattr(F, "scaled_dot_product_attention"):
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
else:
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
and I have added this patch to the run_clm.py example, following the FastChat style.
##run_clm.py
python -m torch.distributed.run --nproc_per_node=8 --nnode=${num_node} --node_rank=${machine_rank} --master_addr=xxx \
--master_port=9901 run_clm.py \
--model_name_or_path llama-7b \
--dataset_name nRuaif/lightnovel-2048 \
--learning_rate 1e-5 --per_device_train_batch_size 16 --gradient_accumulation_steps 2 \
--per_device_eval_batch_size 16 --num_train_epochs 2 \
--warmup_steps 5000 --preprocessing_num_workers 16 \
--report_to "tensorboard" --weight_decay 0.1 \
--output_dir "xxx" --lora_r 16 --block_size 2048 \
--bf16 --bf16_full_eval --run_name llama_test \
--do_train --seed 42 --data_seed 42 \
--log_on_each_node false --max_grad_norm 0.7 \
--dataloader_num_workers 16 \
--ddp_timeout 108000 \
--use_fast_tokenizer false \
--save_steps 128 --logging_steps 10 --gradient_checkpointing true --deepspeed ds_config_zero3.json
ds_config_zero3.json
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 1e-6,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 1,
"wall_clock_breakdown": true,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}
I had run run_clm.py twice |
@dumpmemory Thanks for providing this. To help us debug, could you provide information about the number of nodes and machine rank used when launching the script? Have you tried running this without the patch for flash attention? Note, it's possible to apply flash attention to a transformers model by install
cc @pacman100 as it involves deepspeed. |
I have run it with 2*8 h100 gpus. |
It seems that if the checkpoint is resumed from epoch 0, the loss is consistent, but if it is resumed from epoch 1.2, the loss starts to behave strangely. |
@muellerzr is something related to huggingface/accelerate#1466 ? |
@amyeroberts can u help me out ? |
from 4.29.2
is this part related to this issue ?
|
change
to
fix this issue. |
Nice find @dumpmemory |
Would you like to open a PR with this solution? Otherwise I'll get to it today or tomorrow :) |
yeah, i have tried to add a pr with code logic from 4.29.2 for this issue |
#25872) * fix loss inconsistent after resume #25340 * fix typo * clean code * reformatted code * adjust code according to comments * adjust check_dataloader_randomsampler location * return sampler only * handle sampler is None * Update src/transformers/trainer_pt_utils.py thanks @amyeroberts Co-authored-by: amyeroberts <[email protected]> --------- Co-authored-by: amyeroberts <[email protected]>
huggingface#25872) * fix loss inconsistent after resume huggingface#25340 * fix typo * clean code * reformatted code * adjust code according to comments * adjust check_dataloader_randomsampler location * return sampler only * handle sampler is None * Update src/transformers/trainer_pt_utils.py thanks @amyeroberts Co-authored-by: amyeroberts <[email protected]> --------- Co-authored-by: amyeroberts <[email protected]>
System Info
transformers
version: 4.31.0
Platform: Linux-5.4.119-19.0009.28-x86_64-with-glibc2.35
Python version: 3.10.6
Huggingface_hub version: 0.15.1
Safetensors version: 0.3.1
Accelerate version: 0.21.0
Accelerate config: not found
PyTorch version (GPU?): 2.0.0 (True)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using GPU in script?:
Using distributed or parallel set-up in script?:
Accelerate
version: 0.21.0
Platform: Linux-5.4.119-19.0009.28-x86_64-with-glibc2.35
Python version: 3.10.6
Numpy version: 1.22.2
PyTorch version (GPU?): 2.0.0 (True)
PyTorch XPU available: False
PyTorch NPU available: False
System RAM: 1877.62 GB
GPU type: NVIDIA H800
Accelerate
default config:Not found
DeepSpeed C++/CUDA extension op report
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
JIT compiled ops requires ninja
ninja .................. [OKAY]
op name ................ installed .. compatible
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
[WARNING] using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 2.0.0
deepspeed install path ........... ['/usr/local/lib/python3.10/dist-packages/deepspeed']
deepspeed info ................... 0.9.5, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.0, cuda 12.1
Who can help?
No response
Information
Tasks
examples
folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
Training loss should be the same level before and after resume.
The text was updated successfully, but these errors were encountered: