-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix] Change the condition of ValueError in "convert_checkpoint_from_transformers_to_megatron" #24769
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
norm_factor is still torch.float32 after using model.half So I changed it to register_buffer so I can change it to torch.float16 after using model.half
convert_checkpoint_from_transformers_to_megatron
layers -> attention heads
cc @pacman100 |
pacman100
approved these changes
Jul 13, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch! thank you for the fix
amyeroberts
approved these changes
Jul 13, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for fixing!
The documentation is not available anymore as the PR was closed or merged. |
Lorenzobattistela
pushed a commit
to Lorenzobattistela/transformers
that referenced
this pull request
Jul 13, 2023
…transformers_to_megatron" (huggingface#24769) * fix: half inference error norm_factor is still torch.float32 after using model.half So I changed it to register_buffer so I can change it to torch.float16 after using model.half * fix: Added a variable "persistent=False" * run make style * [fix] Change the condition of ValueError convert_checkpoint_from_transformers_to_megatron * [fix] error wording layers -> attention heads
blbadger
pushed a commit
to blbadger/transformers
that referenced
this pull request
Nov 8, 2023
…transformers_to_megatron" (huggingface#24769) * fix: half inference error norm_factor is still torch.float32 after using model.half So I changed it to register_buffer so I can change it to torch.float16 after using model.half * fix: Added a variable "persistent=False" * run make style * [fix] Change the condition of ValueError convert_checkpoint_from_transformers_to_megatron * [fix] error wording layers -> attention heads
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The "target_tensor_model_parallel_size" is related to "num_attention_heads", and the "target_pipeline_model_parallel_size" is related to "num_hidden_layers".
However, the old code had "target_tensor_model_parallel_size" related to "num_hidden_layers".
So we modified the code and added the part about "target_tensor_model_parallel_size".
Thanks!