
Trainer resume from checkpoint: the learning rate is not the same as when retraining; the learning rate is discontinuous #34053

Open
LBJ6666 opened this issue Oct 10, 2024 · 19 comments

Comments

@LBJ6666

LBJ6666 commented Oct 10, 2024

System Info

  • Platform: Windows-10
  • transformers version: 4.43.4
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.3.1+cu121

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

If the Trainer does not set a warmup, the lr_scheduler is set to linear, and training is resumed from an interruption to complete all steps, the learning rates differ from those obtained when training all steps from the beginning. Here are the specific learning rates:

Learning rates for training from the beginning for each step:

  • Step 1: "learning_rate": 1e-05,
  • Step 2: "learning_rate": 1e-05,
  • Step 3: "learning_rate": 9e-06,
  • Step 4: "learning_rate": 8.000000000000001e-06,
  • Step 5: "learning_rate": 7e-06,
  • Step 6: "learning_rate": 6e-06,
  • Step 7: "learning_rate": 5e-06,
  • Step 8: "learning_rate": 4.000000000000001e-06,
  • Step 9: "learning_rate": 3e-06,
  • Step 10: "learning_rate": 2.0000.

If training is continued from a checkpoint at step 5, the learning rates for each step are:

  • Step 6: "learning_rate": 7e-06,
  • Step 7: "learning_rate": 7e-06,
  • Step 8: "learning_rate": 6e-06,
  • Step 9: "learning_rate": 5e-06,
  • Step 10: "learning_rate": 4.000000000000001e-06.

Why are the learning rates for step 6 and step 7 different when training continues from a checkpoint compared to training from the start?

Reproduction steps:

  1. Train from the beginning for 10 steps, save a checkpoint for each step, and record the learning rate in each step.
  2. Delete the checkpoints for steps 6 through 7 in the folder.
  3. Then use trainer.train(resume_from_checkpoint=True) to continue training from step 5, and after training is completed, record the learning rate in the new checkpoint.

Expected behavior

Please explain why the learning rate is not continuous as it is when training from the beginning, for example:
Step 6: "learning_rate": 6e-06,
Step 7: "learning_rate": 5e-06.
........
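
For reference, the decay pattern above is consistent with a linear schedule and no warmup over 10 steps. A minimal standalone sketch (the dummy parameter and optimizer exist only to drive the scheduler; the exact step at which the Trainer logs each value may be offset from this):

import torch
from transformers import get_scheduler

# Dummy parameter and optimizer, only needed so the scheduler has something to drive.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-5)
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=10)

for step in range(1, 11):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr()[0])
# Prints an (approximately) linear decay: 9e-06, 8e-06, ..., 1e-06, 0.0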

@LysandreJik
Member

cc @muellerzr @SunMarc

LysandreJik reopened this on Nov 19, 2024
huggingface deleted a comment from the github-actions bot on Nov 19, 2024
@LysandreJik
Member

I'm adding a "Good First Issue" tag; the team is under tight bandwidth at the moment so any PR to help solve this is welcome.

@SunMarc
Member

SunMarc commented Nov 19, 2024

Hey @LBJ6666, could you please share a minimal reproducer so that we can quickly find the issue? Thanks!

@LBJ6666
Author

LBJ6666 commented Nov 20, 2024

@SunMarc, thanks for the response. I wrote a simple example.

import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import numpy as np

def tokenize_function(examples):
    # Causal LM objective: reuse the tokenized prompts as labels.
    model_inputs = tokenizer(examples["prompt"], padding="max_length", truncation=True)
    model_inputs["labels"] = np.array(model_inputs['input_ids']).astype(np.int64)
    return model_inputs

dataset = load_dataset("fka/awesome-chatgpt-prompts")

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenized_datasets = dataset.map(tokenize_function, batched=True)

output_dir = "./output"
training_args = TrainingArguments(
    output_dir=output_dir,
    fp16=True,                     # mixed precision, relevant to the reported behaviour
    max_steps=20,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    logging_dir=os.path.join(output_dir, "log"),
    logging_strategy="steps",
    logging_steps=1,               # log the learning rate at every step
    report_to="tensorboard",
    save_strategy='steps',
    save_steps=1,                  # save a checkpoint at every step
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
)

trainer.train(
    resume_from_checkpoint=False   # set to True for the resumed run described below
)
trainer.save_state()

Run the training and record the learning rates for the 20 steps as saved in trainer_state.json within checkpoint-20.
Then delete checkpoints 16-20 to simulate an interrupted training process.
Set resume_from_checkpoint=True to resume training from step 16, and compare the learning rates in the final trainer_state.json after training completes.

Alternatively, compare the TensorBoard logs.
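
For clarity, a minimal sketch of the comparison described above, continuing from the script (the checkpoint directory names are an assumption based on output_dir and save_steps=1):

import shutil

# Simulate an interruption by removing the last few checkpoints.
for step in range(16, 21):
    shutil.rmtree(os.path.join(output_dir, f"checkpoint-{step}"), ignore_errors=True)

# Resume from the latest remaining checkpoint (checkpoint-15), i.e. training continues at step 16;
# then compare the learning rates in the final trainer_state.json (or the TensorBoard logs)
# with those from the uninterrupted run.
trainer.train(resume_from_checkpoint=True)
trainer.save_state()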

@Knight7561

Verifying my reproduction steps for the sample example above:

  1. Case A: Ran the trainer for 20 steps and saved the JSON log.
  2. Case B: Deleted checkpoints 15 and 16, and reran the trainer from checkpoint-14.

However, I found the learning rate was the same in both cases. Please review the screenshot and let me know if I am missing any step in the reproduction.

[Screenshot: 2024-11-21 at 12:09:35]

@SunMarc
Member

SunMarc commented Nov 21, 2024

Thanks for testing this @Knight7561! Could you check, @LBJ6666?

@LBJ6666
Author

LBJ6666 commented Nov 22, 2024

Thanks! @Knight7561 @SunMarc
After breakpoint debugging, it seems the issue was caused by gradient overflow during the first few steps of training when using FP16.
This led to the optimizer steps being skipped; the learning rate scheduler step was skipped as well, causing the learning rate to repeat the previous step's value.
At this position:

optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
    # Delay optimizer scheduling until metrics are generated
    if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        self.lr_scheduler.step()
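
For reference, the same mechanism can be reproduced with plain PyTorch AMP outside of the Trainer: GradScaler.step() silently skips optimizer.step() when the gradients contain inf/NaN, and the scale then shrinks on update, which is essentially how Accelerate's optimizer_step_was_skipped is derived. A minimal sketch (assumes a CUDA device; this is an illustration, not Trainer code):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=10)
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    optimizer.zero_grad()
    with torch.autocast("cuda", dtype=torch.float16):
        loss = model(torch.randn(8, 4, device="cuda")).pow(2).mean()
    scaler.scale(loss).backward()
    scale_before = scaler.get_scale()
    scaler.step(optimizer)   # internally skipped if the unscaled grads contain inf/NaN
    scaler.update()          # the scale shrinks when the step was skipped
    if scaler.get_scale() >= scale_before:
        scheduler.step()     # only advance the LR when the optimizer actually stepped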

@SunMarc
Member

SunMarc commented Nov 22, 2024

Is there anything that could have been added to make this easier to spot? Maybe some logs?

@LBJ6666
Author

LBJ6666 commented Nov 23, 2024

I think the issue lies with gradient overflow in FP16. After switching to BF16, the optimization process worked normally.
Logs from FP16=True showing the issue (excerpt from trainer_state.json):

"log_history": [
{  "epoch": 0.011764705882352941,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.3054,  "step": 1},
{  "epoch": 0.023529411764705882,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.5318,  "step": 2},
{  "epoch": 0.03529411764705882,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 10.0388,  "step": 3},
{  "epoch": 0.047058823529411764,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.9219,  "step": 4},
{  "epoch": 0.058823529411764705,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.9054,  "step": 5},
{  "epoch": 0.07058823529411765,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 10.3334,  "step": 6},
{  "epoch": 0.08235294117647059,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.2773,  "step": 7},
{  "epoch": 0.09411764705882353,  "grad_norm": Infinity,  "learning_rate": 1e-05,  "loss": 9.814,  "step": 8},
{  "epoch": 0.10588235294117647,  "grad_norm": 211.26919555664062,  "learning_rate": 9.755282581475769e-06,  "loss": 9.4853,  "step": 9},
{  "epoch": 0.11764705882352941,  "grad_norm": 222.30563354492188,  "learning_rate": 9.045084971874738e-06,  "loss": 8.9275,  "step": 10},
]

Logs from BF16=True (normal behavior):

"log_history": [
{  "epoch": 0.011764705882352941,  "grad_norm": 205.610107421875,  "learning_rate": 9.755282581475769e-06,  "loss": 9.5877,  "step": 1},
{  "epoch": 0.023529411764705882,  "grad_norm": 219.57696533203125,  "learning_rate": 9.045084971874738e-06,  "loss": 8.2459,  "step": 2},
{  "epoch": 0.03529411764705882,  "grad_norm": 224.68093872070312,  "learning_rate": 7.938926261462366e-06,  "loss": 7.8228,  "step": 3},
{  "epoch": 0.047058823529411764,  "grad_norm": 206.76161193847656,  "learning_rate": 6.545084971874738e-06,  "loss": 6.0532,  "step": 4},
{  "epoch": 0.058823529411764705,  "grad_norm": 201.87257385253906,  "learning_rate": 5e-06,  "loss": 4.9474,  "step": 5},
{  "epoch": 0.07058823529411765,  "grad_norm": 197.5258331298828,  "learning_rate": 3.4549150281252635e-06,  "loss": 4.6874,  "step": 6},
{  "epoch": 0.08235294117647059,  "grad_norm": 189.73866271972656,  "learning_rate": 2.061073738537635e-06,  "loss": 4.0804,  "step": 7},
{  "epoch": 0.09411764705882353,  "grad_norm": 189.0591583251953,  "learning_rate": 9.549150281252633e-07,  "loss": 4.0516,  "step": 8},
{  "epoch": 0.10588235294117647,  "grad_norm": 183.7292022705078,  "learning_rate": 2.447174185242324e-07,  "loss": 3.732,  "step": 9},
{  "epoch": 0.11764705882352941,  "grad_norm": 191.3872528076172,  "learning_rate": 0.0,  "loss": 3.6382,  "step": 10},
],

Thus, when continuing training with FP16, if gradient overflow occurs, the learning rate scheduler is skipped.
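
One way to spot this after the fact is to scan trainer_state.json for entries whose grad_norm is infinite or whose learning rate repeats the previous value. A small sketch (the checkpoint path is an assumption; the repeated-LR check is only a heuristic):

import json
import math

# Path is assumed; point it at the checkpoint you want to inspect.
with open("output/checkpoint-20/trainer_state.json") as f:
    state = json.load(f)

prev_lr = None
for entry in state["log_history"]:
    lr = entry.get("learning_rate")
    if lr is None:
        continue
    if math.isinf(entry.get("grad_norm", 0.0)) or lr == prev_lr:
        print(f"step {entry['step']}: lr={lr} (optimizer step likely skipped)")
    prev_lr = lr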

@hsilva664

Hello, I tried reproducing the issue with fp16, and the undesired behaviour was present. The code is as posted before, except for a new part I added to set the seed to zero. I first ran the 20 steps while saving results, then deleted the checkpoints from step 12 onwards, switched the flag resume_from_checkpoint to True, and ran again.

As reported by @LBJ6666, when there is gradient overflow, the optimizer does not run and, as a consequence, the LR scheduler does not get updated. However, when reproducing, the issue was a little worse than that in my case. When running the 20 steps without interruption, gradient overflow did not occur on my machine. However, it did occur when I loaded checkpoint 11.

After some inspection, I noticed that there is a GradScaler object stored as an attribute of the Accelerator object. When running from scratch, the scaler multiplies the loss from the step-11 model by 256 before backpropagating. However, when loading from file, the scaler multiplies it by 65536 instead, which causes the overflow. To solve it, I tried adding a function to save and load the scaler as well, similarly to how the scheduler and optimizer are saved and loaded. It seems to work more consistently.

Still, there are some nuances to saving and reloading that I'm not sure how to handle, such as when using DeepSpeed, SageMaker or XLA. I just copied what was done for the scheduler for now and left the code as a draft so others can look at it and say how best to change it (the structure is a little unnecessarily duplicated, but this was, again, just a draft to see whether it would work). Also, when loading the GradScaler, the performance between the loaded and from-scratch models gets very similar, but it does not become exactly equal, as I think it should. I've opened the pull request at #34932.
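
For illustration, a minimal sketch of the idea behind the draft PR: persist the scaler state (dynamic scale and growth tracker) next to the optimizer and scheduler states, and restore it when resuming. The helper names and file name are assumptions, not the actual Trainer code:

import os
import torch

SCALER_FILE = "scaler.pt"  # assumed file name

def save_scaler(scaler: torch.cuda.amp.GradScaler, output_dir: str) -> None:
    # state_dict() holds the current scale, growth/backoff factors and the growth tracker.
    torch.save(scaler.state_dict(), os.path.join(output_dir, SCALER_FILE))

def load_scaler(scaler: torch.cuda.amp.GradScaler, checkpoint_dir: str) -> None:
    path = os.path.join(checkpoint_dir, SCALER_FILE)
    if os.path.isfile(path):
        scaler.load_state_dict(torch.load(path))

# e.g. save_scaler(trainer.accelerator.scaler, "output/checkpoint-11") at save time,
# and load_scaler(trainer.accelerator.scaler, "output/checkpoint-11") when resuming.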

@SunMarc
Member

SunMarc commented Nov 26, 2024

> After some inspection, I noticed that there is a GradScaler object stored as an attribute of the Accelerator object. When running from scratch, the scaler multiplies the loss from the step-11 model by 256 before backpropagating. However, when loading from file, the scaler multiplies it by 65536 instead, which causes the overflow. To solve it, I tried adding a function to save and load the scaler as well, similarly to how the scheduler and optimizer are saved and loaded. It seems to work more consistently.

Thanks for the nice investigation. This looks indeed like a bug. cc @muellerzr

Could you check, @Knight7561 and @LBJ6666, whether this is indeed the issue?

@Knight7561

Knight7561 commented Nov 26, 2024

Sure @SunMarc, let me debug this and try to reproduce the above scenario to see what's going on.

Before I dig into the issue, these are the test results I got, and I believe they match the excerpts above.
[Screenshot: 2024-11-26 at 13:20:31]

I am just making sure my reproduction steps are right. Can you confirm, @hsilva664 @LBJ6666?

@hsilva664

My results are close to yours, slightly different because of the seed. I set the seed to zero at the very beginning with:

import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(0)

These are my results when running uninterrupted (I got them from the trainer_state.json):

[Screenshot: learning rates from trainer_state.json, uninterrupted run]

They look similar to yours. Below are the results from before I modified the script, when running starting from the model saved at step 11. The overflow at step 11 is because of the self.accelerator.scaler._scale parameter, as explained previously.

[Screenshot: learning rates when resuming from the step-11 checkpoint, before the fix]

Below are the results after the draft modifications from my pull request, where I save and load the scaler as well. These results also start from the checkpoint of step 11.

[Screenshot: learning rates when resuming from the step-11 checkpoint, after the scaler fix]

They are closer to the initial result. Still, one thing that caught my attention is that they are close, but not equal. Since the seed is set to be the same, the transformers save/reload logic also seems to save and restore the RNG state, and both runs were done on the same machine, I believe they should have been exactly the same rather than just similar. Perhaps there is another issue.

@Knight7561

Thanks for the results @hsilva664. So, if I understand correctly: you ran uninterrupted (got no inf except in the early steps), then you deleted checkpoint-11 and reran with resume_from_checkpoint set to True, which led to inf again (unexpected behaviour after checkpoint-11), right? And then, when you fixed the scaler issue in your PR, that's when you got proper results as in the third screenshot?

Am I understanding correctly? @hsilva664

@hsilva664

I did not delete checkpoint 11, only checkpoints 12 onwards. Other than this, your explanation is correct.

@LBJ6666
Author

LBJ6666 commented Nov 27, 2024

@hsilva664 Thank you for the further testing.

I have successfully tested the code from PR #34932, and the results show that the logs for running the 20 training steps in one go and the logs for resuming training from checkpoint 11 (training steps 12-20) are identical, including the loss, gradient norm, and learning rate.

Regarding result reproducibility, the slight differences observed between two runs on the same machine are due to non-determinism. The following configuration makes the results reproducible:

import os
# Required by cuBLAS for deterministic kernels when torch.use_deterministic_algorithms(True) is set.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
import numpy as np
import random
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
set_seed(0)

@Knight7561

So, @LBJ6666, does the PR solve your issue?

@hsilva664

Hello, judging by the previous discussion, it seems so. Still, the PR is not ready to be merged. As requested, I tried adding some tests that would detect the save/reload errors better than the tests that were already there. However, these new tests require setting deterministic behaviour differently, which can be done via flags and environment variables. My new tests generate consistent results when run in isolation or together with the tests from the TrainerIntegrationTest class, but they do not pass when all tests in the containing file are run together.

I assume this might be because another test, from another class in the same file, running on some parallel worker, somehow changes the global state in a way that breaks the deterministic behaviour I had set. Since I'm not sure how to proceed, I have left it here until somebody takes a look. More details are in the PR discussion.

@LBJ6666
Author

LBJ6666 commented Dec 11, 2024

@Knight7561, yes, after testing with my code, it resolved my issue. I look forward to more comprehensive testing.
