
Question about torch_dtype when running run_orpo.py #174

Closed

sylee96 opened this issue Jun 23, 2024 · 6 comments

Comments
sylee96 commented Jun 23, 2024

I have been using run_orpo.py with my personal data successfully. However, as I use it, a question has come up.

When I look at the code for run_orpo.py, I see that there is code to match torch_dtype to the dtype of the pretrained model. However, when I actually train and save the model, the dtype gets changed to fp32 even though the pretrained model's dtype was bf16. Why is this happening?
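For reference, the dtype matching I mean looks roughly like this (a minimal sketch of the pattern, not the exact code from run_orpo.py; the names are illustrative):

import torch

# Illustrative stand-in for the value parsed from the YAML config, e.g. torch_dtype: bfloat16
config_torch_dtype = "bfloat16"

# The string from the config is turned into a real torch.dtype and then
# forwarded to AutoModelForCausalLM.from_pretrained via the model kwargs.
torch_dtype = (
    config_torch_dtype
    if config_torch_dtype in ("auto", None)
    else getattr(torch, config_torch_dtype)
)
model_kwargs = dict(torch_dtype=torch_dtype)
print(model_kwargs)  # {'torch_dtype': torch.bfloat16}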

alvarobartt (Member) commented
Hi here! Not sure if that's related to #175 at all, but feel free to upgrade the trl version and re-run as mentioned in that issue 🤗

Other than that, could you share the configuration you're using so that we can reproduce and debug that issue further? Thanks in advance!

sylee96 (Author) commented Jul 16, 2024

Hi alvarobartt,

Here are the details of the environment and configuration I used:

  • Python version: 3.10.12
  • PyTorch version: 2.3.1
  • transformers library version: 4.42.3
  • trl library version: 0.9.6
  • accelerate library version: 0.32.1

run_orpo.py configuration:
torch_dtype setting: bf16

With this setup, the dtype of the model changes to fp32 when saving the model, even though it was set to bf16. Please let me know if you need any additional information.

Thanks!

alvarobartt (Member) commented
Thanks for that @sylee96! To better understand the problem: the training is indeed happening in bfloat16, but save_pretrained is storing the weights in float32 instead? How did you check that? Could you also share the exact command you're running, i.e. python run_orpo.py ... or accelerate launch ... run_orpo.py ..., so we can try to reproduce on our end? Thanks again 🤗

sylee96 (Author) commented Jul 22, 2024

Thanks for answering, @alvarobartt.

I use a command like this:

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file orpo/configs/fsdp.yaml orpo/run_orpo.py orpo/configs/config_full.yaml

When I checked the dtype of the gemma2, llama3, or qwen2 models before training, it was set to bfloat16. But when I checked the dtype of the models after training and saving, I found it had changed to float32.

To check the model's dtype, I used these lines:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
print(model.dtype)

alvarobartt (Member) commented
> Thanks for answering, @alvarobartt.

Anytime @sylee96!

> To check the model's dtype, I used these lines:
>
> from transformers import AutoModelForCausalLM
>
> model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
> print(model.dtype)

You should load it as follows, i.e. specifying the torch_dtype to use when loading the model; otherwise torch.float32 is used by default.

import torch
from transformers import AutoModelForCausalLM

# Load the checkpoint in half precision instead of the default float32
# (use torch.bfloat16 if that is the dtype you trained and saved in).
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
print(model.dtype)

Hope that helps! 🤗
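If you also want to verify what dtype was actually written to disk (independently of the dtype used when re-loading), you can inspect the checkpoint tensors directly; here is a minimal sketch assuming the model was saved as safetensors shards under a hypothetical output_dir:

import glob

from safetensors import safe_open

output_dir = "path/to/saved/model"  # hypothetical path to the saved checkpoint

# Print the on-disk dtype of every tensor in the safetensors shards; this
# reflects what save_pretrained wrote, regardless of how from_pretrained
# later loads the checkpoint.
for shard in sorted(glob.glob(f"{output_dir}/*.safetensors")):
    with safe_open(shard, framework="pt", device="cpu") as f:
        for name in f.keys():
            print(name, f.get_tensor(name).dtype)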

sylee96 (Author) commented Jul 30, 2024

Thanks for your help, @alvarobartt!

I will close this issue.

sylee96 closed this as completed Jul 30, 2024