RoPE: model-agnostic RoPE refactor #31999
base: main
Conversation
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts Interpolation and Attention Scaling methods, improving upon existing RoPE interpolation methods for longer context window sizes. Fine-tuned models maintain their original performance across benchmarks while enabling efficient extrapolation and transfer learning for quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:
- LLaMA
- Falcon
- GPT-NeoX
- Olmo
- Persimmon
- Phi
- StableLM
- OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both short and long sequence inputs. For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <[email protected]>
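For readers unfamiliar with the two components named in the commit message, here is a minimal, self-contained sketch of NTK-by-parts interpolation plus attention scaling. This is not the PR's code: the function names, the `beta_fast`/`beta_slow` defaults, and the 4096 original context length are illustrative assumptions based on the YaRN paper (https://arxiv.org/abs/2309.00071).

```python
import math

import torch


def compute_yarn_inv_freq(
    dim: int,
    base: float = 10000.0,
    factor: float = 4.0,
    original_max_position_embeddings: int = 4096,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
) -> torch.Tensor:
    # Plain RoPE inverse frequencies, one per pair of hidden dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # Fully position-interpolated frequencies (linear interpolation by `factor`).
    inv_freq_interp = inv_freq / factor

    # NTK-by-parts: find the dimension indices whose wavelengths correspond to
    # `beta_fast` / `beta_slow` full rotations over the original context window.
    def dim_for_rotations(num_rotations: float) -> float:
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = math.floor(dim_for_rotations(beta_fast))
    high = math.ceil(dim_for_rotations(beta_slow))

    # Linear ramp between `low` and `high`; 1.0 means "keep the original frequency".
    ramp = (torch.arange(dim // 2, dtype=torch.float32) - low) / max(high - low, 1)
    keep_original = 1.0 - ramp.clamp(0.0, 1.0)

    # High-frequency dims keep the original frequencies, low-frequency dims get interpolated.
    return inv_freq * keep_original + inv_freq_interp * (1.0 - keep_original)


def yarn_attention_factor(factor: float) -> float:
    # "Attention scaling" from the paper: a scalar later multiplied into cos/sin.
    return 0.1 * math.log(factor) + 1.0
```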
Iterate on YaRN implementation for LLaMA and remove diff from remaining models for increased PR modularity. This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned') from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <[email protected]>
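As context for the first bullet above (merging the two dictionaries), a hedged illustration of what a merged `rope_scaling` dictionary could look like; the key names and values are assumptions for illustration, not the PR's final schema.

```python
# Hypothetical shape of a merged `rope_scaling` dict after folding the YaRN options
# into the existing dictionary; key names are assumptions, not the PR's final schema.
rope_scaling = {
    "type": "yarn",                            # which RoPE extension to apply
    "factor": 4.0,                             # context-length scaling factor
    "original_max_position_embeddings": 4096,  # pretraining context length
    "beta_fast": 32.0,                         # NTK-by-parts ramp boundaries
    "beta_slow": 1.0,
}
```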
- Comply with the tensor building logic introduced in huggingface#30743
- Add referencing to the optimized Attention Factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <[email protected]>
Off to a good start.
We want this to be easily configurable IMO, and with the least amount of checks on our side!
cos = cos * self.rope_config["attention_factor"]
sin = sin * self.rope_config["attention_factor"]
if this lives in a config vs in a tensor or buffer we will have device issues + we have less freedom IMO and no idea about the dtype
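A minimal sketch of the alternative this comment points at, assuming the scalar comes from whichever rope type was selected; the class and argument names are illustrative, not the PR's code.

```python
import torch
from torch import nn


class RotaryEmbeddingWithScale(nn.Module):
    """Illustrative only: keeps the attention factor as a buffer so it follows device/dtype moves."""

    def __init__(self, inv_freq: torch.Tensor, attention_factor: float = 1.0):
        super().__init__()
        # Buffers travel with `.to(device)` / dtype casts, unlike values stored in a config dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("attention_factor", torch.tensor(attention_factor), persistent=False)

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor):
        # position_ids: (batch, seq_len) -> cos/sin of shape (batch, seq_len, dim)
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos() * self.attention_factor, emb.sin() * self.attention_factor
```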
config = LlamaConfig(**kwargs)
config.rope_theta = base
config.max_position_embeddings = max_position_embeddings
config.head_dim = dim  # this one doesn't actually exist, will only be used in the deprecation transition
if scaling_factor == 1.0 and len(kwargs) == 0:
    config.rope_scaling = None
else:
    config.rope_scaling = {"type": "default", "factor": scaling_factor}
    config.rope_scaling |= kwargs  # may overwrite "type"
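A small illustration of the `|=` merge in the last line above (Python 3.9+ dict union): any key present in the kwargs takes precedence, which is why "type" may be overwritten.

```python
rope_scaling = {"type": "default", "factor": 2.0}
rope_scaling |= {"type": "dynamic"}  # dict union in place: right-hand keys win
assert rope_scaling == {"type": "dynamic", "factor": 2.0}
```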
that's fairly weird (init a config) but only happens once, should be alright
It's the easiest path for the deprecation: in v4.45 we just delete these lines 👼
(all RoPE models with …)
Just gave a quick look at the API which looks good to me. Very nice and clean changes with the deprecation cycle.
Thanks for iterating on the PR! (Would really like to have @amyeroberts take a look at the PR as well if possible)
I'm trying to train the … model. Will the PR fix this issue? If yes, when can we expect this to be merged into main?
MMmm what's weird is that this model uses code on the hub.
Way to Reproduce:
That model is "code on the hub" so it's kind of expected
Note: splitting this PR into multiple smaller ones, as the refactor needs extra attention in some models (e.g. …). Keeping the PR open as a reference until all models have the new RoPE structure.
What does this PR do?
This PR:
- Adds `longrope` as part of the model-agnostic refactor on Phi3 (closes Plans to Integrate LongRoPE into LLaMA? #31992); with `longrope`, Phi3's checkpoints are now loadable.

👉 Built on top of the YaRN PR (#30910)
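As a rough illustration of the model-agnostic idea this PR and #30910 work toward, here is a hypothetical dispatch sketch; the registry name, helper functions, and config keys below are assumptions for illustration, not the implementation under review.

```python
import torch


def _default_inv_freq(dim: int, base: float) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))


def _linear_inv_freq(dim: int, base: float, factor: float) -> torch.Tensor:
    return _default_inv_freq(dim, base) / factor


# One registry instead of per-model rotary classes: the rope type named in the config
# selects the function that builds the inverse frequencies.
ROPE_INIT_FUNCTIONS = {
    "default": lambda cfg: _default_inv_freq(cfg["head_dim"], cfg["rope_theta"]),
    "linear": lambda cfg: _linear_inv_freq(cfg["head_dim"], cfg["rope_theta"], cfg["factor"]),
}


def build_inv_freq(cfg: dict) -> torch.Tensor:
    rope_type = cfg.get("rope_type", "default")
    return ROPE_INIT_FUNCTIONS[rope_type](cfg)


# Example usage:
inv_freq = build_inv_freq({"head_dim": 128, "rope_theta": 10000.0, "rope_type": "linear", "factor": 2.0})
```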
Review
Key files to check, IN THIS SPECIFIC ORDER:
👉 Other relevant files include `phi3` (`longrope`) and `recurrentgemma` (a few custom changes)

Models that require future changes for standardization
These models do not yet have `cache_positions`, and therefore they are not changed as part of this PR (the new class is built with the new pattern in mind). A future PR is needed on these models, where both `cache_positions` and this new model-agnostic RoPE are added.

Models that were NOT changed but have RoPE: