
Pythia regression in transformers==4.36.2 vs transformers==4.30.1 #28316

Closed
2 of 4 tasks
vwxyzjn opened this issue Jan 2, 2024 · 8 comments · Fixed by #28602

@vwxyzjn (Contributor) commented Jan 2, 2024

System Info

Happy New Year all!

  • transformers version: 4.36.2
  • Platform: Linux-5.15.0-1049-aws-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.25.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.6.8 (cpu)
  • Jax version: 0.4.8
  • JaxLib version: 0.4.7
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes, via accelerate

Who can help?

Maybe @younesbelkada @ArthurZucker?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is a minimal reproduction: https://gist.github.com/vwxyzjn/e67e0bb28363e6fbb309bd0b78922a93. I ran the same repro.py with transformers==4.36.2 and transformers==4.30.1 and got slightly different losses, even though the data and all other dependencies are precisely the same.
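
For context, the gist is the source of truth; the sketch below is only a rough guess at the shape of such a repro (the model name, data, hyperparameters, and the pairwise objective are all assumptions, the last suggested by the losses hovering near ln 2 ≈ 0.693): fix the seed and the data, train briefly, and print the loss every 8 updates so two installs can be diffed line by line.

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, set_seed

set_seed(1)
model_name = "EleutherAI/pythia-1b-deduped"  # assumption; the thread covers 1b/2.8b/6.9b
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
backbone = AutoModel.from_pretrained(model_name)
scalar_head = torch.nn.Linear(backbone.config.hidden_size, 1)
optimizer = torch.optim.AdamW(
    list(backbone.parameters()) + list(scalar_head.parameters()), lr=1e-5
)

# stand-ins for the fixed preference dataset used in the gist
chosen = ["placeholder chosen response"] * 120
rejected = ["placeholder rejected response"] * 120

print("epoch: 0")
for update in range(1, len(chosen) + 1):
    batch = tokenizer(
        [chosen[update - 1], rejected[update - 1]], return_tensors="pt", padding=True
    )
    hidden = backbone(**batch, output_hidden_states=True).hidden_states[-1]
    rewards = scalar_head(hidden[:, -1, :]).squeeze(-1)  # score read at the final position
    loss = -F.logsigmoid(rewards[0] - rewards[1])        # pairwise preference loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if update % 8 == 1 and update > 1:  # prints at updates 9, 17, 25, ...
        print(f"update: {update}, loss: {loss.item()}")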

python repro.py # 4.36.2
epoch: 0
update: 9, loss: 0.6855486035346985
update: 17, loss: 0.6901922225952148
update: 25, loss: 0.6883461475372314
update: 33, loss: 0.6975809931755066
update: 41, loss: 0.6995139122009277
update: 49, loss: 0.6912401914596558
update: 57, loss: 0.698995053768158
update: 65, loss: 0.7005056142807007
update: 73, loss: 0.7048475742340088
update: 81, loss: 0.6950501203536987
update: 89, loss: 0.7148610949516296
update: 97, loss: 0.694938063621521
update: 105, loss: 0.6957464814186096
update: 113, loss: 0.6873601675033569

python repro.py # 4.30.1
epoch: 0
update: 9, loss: 0.6904680132865906
update: 17, loss: 0.6958459615707397
update: 25, loss: 0.6878675818443298
update: 33, loss: 0.6945885419845581
update: 41, loss: 0.6920362710952759
update: 49, loss: 0.6866860389709473
update: 57, loss: 0.685932457447052
update: 65, loss: 0.6930047273635864
update: 73, loss: 0.6854068636894226
update: 81, loss: 0.6739884614944458
update: 89, loss: 0.6913299560546875
update: 97, loss: 0.7025052309036255

Regression in end-to-end reward model training performance

This difference causes a regression in reward model training. With the code and data set to be exactly the same, the average reward model accuracy across four random seeds is as follows:

  • transformers==4.36.2, accelerate==0.25.0, deepspeed==0.12.6
    • EleutherAI/pythia-1b-deduped: 0.6276
    • EleutherAI/pythia-2.8b-deduped: 0.6438
    • EleutherAI/pythia-6.9b-deduped: 0.65
  • transformers==4.30.1, accelerate==0.25.0, deepspeed==0.12.6
    • EleutherAI/pythia-1b-deduped: 0.6327
    • EleutherAI/pythia-2.8b-deduped: 0.6713
    • EleutherAI/pythia-6.9b-deduped: 0.6923

The SFT losses are relatively similar (except maybe for 6.9B, where there was a minor loss explosion with transformers==4.36.2).

Here is the report: https://wandb.ai/costa-huang/tldr_summarize/reports/pythia-transformers-regression--Vmlldzo2Mzk3OTQ1


Here is the code comparison: the code is identical and only the dependencies differ.

Expected behavior

There shouldn't be a regression in performance.

vwxyzjn added a commit to vwxyzjn/lm-human-preference-details that referenced this issue Jan 3, 2024
@ArthurZucker (Collaborator)

Sorry, but the source of the regression might be pretty much anything. If the model supports SDPA, it can come from SDPA; if the tokenizer had a bug before, it might be the tokenizer, etc.
I can't debug this as is. Would you mind comparing with closer transformers releases? This might help isolate it, but otherwise the scope is just way too broad. The modeling code has or might have changed, the caching mechanism has changed, torch operators might have been fixed, etc.

@vwxyzjn (Contributor, Author) commented Jan 3, 2024

Thanks for the reply! I compared the following transformers releases and noticed that the losses become different starting with 4.36.0. I also validated end-to-end that 4.33.2 is fine.
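
For anyone repeating this sweep, here is a small harness sketch (it assumes repro.py sits in the working directory and that reinstalling transformers between runs is acceptable) that installs each release in turn and re-runs the repro in a fresh process:

import subprocess
import sys

VERSIONS = ["4.30.1", "4.33.2", "4.35.1", "4.36.0", "4.36.1", "4.36.2"]

for version in VERSIONS:
    # swap the installed transformers release, then run the repro in a new interpreter
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", f"transformers=={version}"],
        check=True,
    )
    print(f"=== transformers=={version} ===")
    subprocess.run([sys.executable, "repro.py"], check=True)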

python repro.py # 4.30.1
epoch: 0
update: 9, loss: 0.6904680132865906
update: 17, loss: 0.6958459615707397
update: 25, loss: 0.6878675818443298
update: 33, loss: 0.6945885419845581
update: 41, loss: 0.6920362710952759
update: 49, loss: 0.6866860389709473
update: 57, loss: 0.685932457447052
update: 65, loss: 0.6930047273635864
update: 73, loss: 0.6854068636894226
update: 81, loss: 0.6739884614944458
update: 89, loss: 0.6913299560546875
update: 97, loss: 0.7025052309036255


# 4.33.2
epoch: 0
update: 9, loss: 0.6904680132865906
update: 17, loss: 0.6958459615707397
update: 25, loss: 0.6878675818443298
update: 33, loss: 0.6945885419845581
update: 41, loss: 0.6920362710952759
update: 49, loss: 0.6866860389709473
update: 57, loss: 0.685932457447052
update: 65, loss: 0.6930047273635864
update: 73, loss: 0.6854068636894226
update: 81, loss: 0.6739884614944458

# 4.35.1 
===training model===
epoch: 0
update: 9, loss: 0.6904680132865906
update: 17, loss: 0.6958459615707397
update: 25, loss: 0.6878675818443298
update: 33, loss: 0.6945885419845581
update: 41, loss: 0.6920362710952759
update: 49, loss: 0.6866860389709473
update: 57, loss: 0.685932457447052
update: 65, loss: 0.6930047273635864
update: 73, loss: 0.6854068636894226
update: 81, loss: 0.6739884614944458
update: 89, loss: 0.6913299560546875


# 4.36.0
===training model===
epoch: 0
update: 9, loss: 0.6855486035346985
update: 17, loss: 0.6901922225952148
update: 25, loss: 0.6883461475372314
update: 33, loss: 0.6975809931755066
update: 41, loss: 0.6995139122009277
update: 49, loss: 0.6912401914596558
update: 57, loss: 0.698995053768158
update: 65, loss: 0.7005056142807007
update: 73, loss: 0.7048475742340088
update: 81, loss: 0.6950501203536987
update: 89, loss: 0.7148610949516296

# 4.36.1
epoch: 0
update: 9, loss: 0.6855486035346985
update: 17, loss: 0.6901922225952148
update: 25, loss: 0.6883461475372314
update: 33, loss: 0.6975809931755066
update: 41, loss: 0.6995139122009277
update: 49, loss: 0.6912401914596558
update: 57, loss: 0.698995053768158
update: 65, loss: 0.7005056142807007
update: 73, loss: 0.7048475742340088
update: 81, loss: 0.6950501203536987
update: 89, loss: 0.7148610949516296

@ArthurZucker (Collaborator)

Could you try using attn_implementation="eager" instead of sdpa wherever you instantiate a model? That is one of the biggest changes in 4.36! See here
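
A hedged sketch of that check (the model name is a placeholder, and attn_implementation requires transformers >= 4.36): load the model once with the default attention backend and once with "eager", then compare the logits on the same inputs.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-1b-deduped"  # assumption; any of the sizes above
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer("the same batch for both attention backends", return_tensors="pt")

logits = {}
for label, kwargs in (("default", {}), ("eager", {"attn_implementation": "eager"})):
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
    with torch.no_grad():
        logits[label] = model(**inputs).logits
print("max abs diff:", (logits["default"] - logits["eager"]).abs().max().item())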

@ArthurZucker (Collaborator)

Also, the numbers you have don't really seem alarming, no?

@vwxyzjn (Contributor, Author) commented Jan 3, 2024

I ran it with

        self.lm_backbone = AutoModel.from_pretrained(
            config.base_model,
            config=self.config.base_config,
            trust_remote_code=True,
            attn_implementation="eager",
        )

and it did not seem to make a difference.

Also, the numbers you have don't really seem alarming, no?

Yeah, but I guess this is why it's tricky: the numbers do not look that different, but they cause a significant regression for reward model training. Maybe the hidden states index is being messed up somehow? It's using self.scalar_head(output.hidden_states[-1]).
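
A quick, hedged way to test that indexing worry (the model name is an assumption): for the base AutoModel, hidden_states[-1] should be exactly the post-final-layer-norm output, i.e. last_hidden_state, so if the check below prints True on both versions, the index itself is probably not what changed.

import torch
from transformers import AutoModel, AutoTokenizer

model_name = "EleutherAI/pythia-1b-deduped"  # assumption; any of the Pythia sizes
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

inputs = tokenizer("a quick indexing check", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
# the last entry of hidden_states is taken after the final layer norm
print(torch.equal(out.hidden_states[-1], out.last_hidden_state))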

@ArthurZucker (Collaborator)

Oh sorry, if you're using output_hidden_states, eager will be used by default.
I have no idea; pinging @pacman100, our training expert, for ideas, and @younesbelkada for SFT training, which should be more relevant expertise than mine!

@younesbelkada (Contributor)

Hi @vwxyzjn!
Happy new year!
Hmm, this is interesting. I don't have a clear idea either of what could be causing this, but looking at the commit history of the GPTNeoX modeling code, it could be:
1- Attention dropout support: 3927404
2- RoPE scaling: #24653
3- Potentially the gradient checkpointing refactor as well: #27020
If the experiments are not too long to run, can you try checking out each of these commits and see which one might be responsible for the regression?
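
A hedged sketch of that bisection (only the short hash from item 1 is listed; the commits behind #24653 and #27020 would need to be looked up, and repro.py is assumed to be in the working directory): install transformers straight from each candidate commit and re-run the repro.

import subprocess
import sys

CANDIDATE_COMMITS = ["3927404"]  # add the commits behind #24653 and #27020 once known

for commit in CANDIDATE_COMMITS:
    # install transformers at this exact commit, then run the repro in a new process
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet",
         f"git+https://github.com/huggingface/transformers@{commit}"],
        check=True,
    )
    print(f"=== transformers @ {commit} ===")
    subprocess.run([sys.executable, "repro.py"], check=True)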

@ArthurZucker (Collaborator)

I think it's 253f9a3, I'll fix the NaNs!
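
While waiting for the fix, a small hypothetical helper (not from this thread) can narrow down where NaNs appear by hooking every submodule and checking its tensor outputs during a forward pass:

import torch

def register_nan_hooks(model):
    """Attach forward hooks that print every module whose output contains NaNs."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            tensors = output if isinstance(output, tuple) else (output,)
            for t in tensors:
                if isinstance(t, torch.Tensor) and torch.isnan(t).any():
                    print(f"NaN detected in output of {name}")
        handles.append(module.register_forward_hook(hook))
    return handles

# usage sketch: handles = register_nan_hooks(model); run a forward pass;
# then remove the hooks with [h.remove() for h in handles]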
