
RoPE: model-agnostic RoPE refactor #31999

Open · wants to merge 42 commits into base: main

Conversation

@gante (Member) commented on Jul 16, 2024

What does this PR do?

This PR:

  • Refactors RoPE such that it is model-agnostic.
    • RoPE models now only need one class.
    • The class is parameterized by the model config.
    • Based on the model config, the appropriate type of RoPE will be loaded into the class (a minimal sketch of the pattern follows this list).
  • Adds longrope as part of the model-agnostic refactor, applied to Phi3 (closes "Plans to Integrate LongRoPE into LLaMA?" #31992). With longrope, Phi3's checkpoints are now loadable.
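
For illustration, here is a minimal sketch of what a single, config-parameterized RoPE class can look like. It is a simplified assumption based on the description above, not the exact code in this PR; in particular, `ROPE_INIT_FUNCTIONS` and the config fields used here are illustrative.

```python
import torch
import torch.nn as nn


def _compute_default_inv_freq(config, device=None):
    # default RoPE: inverse frequencies derived from theta and the head dimension
    dim = config.hidden_size // config.num_attention_heads
    inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
    return inv_freq, 1.0  # (inverse frequencies, attention scaling factor)


# each RoPE variant ("default", "linear", "dynamic", "yarn", "longrope", ...) would
# register a function that derives its parameters from the model config
ROPE_INIT_FUNCTIONS = {"default": _compute_default_inv_freq}


class RotaryEmbedding(nn.Module):
    def __init__(self, config, device=None):
        super().__init__()
        scaling = getattr(config, "rope_scaling", None) or {}
        rope_type = scaling.get("rope_type", scaling.get("type", "default"))
        inv_freq, self.attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # angles = positions (batch, seq, 1) * inverse frequencies (1, 1, dim/2)
        freqs = position_ids[:, :, None].float() * self.inv_freq[None, None, :].to(x.device)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = (emb.cos() * self.attention_scaling).to(x.dtype)
        sin = (emb.sin() * self.attention_scaling).to(x.dtype)
        return cos, sin
```

Under this pattern, adding a new scaling method becomes a matter of registering another entry in the mapping, without touching the per-model code.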

👉 Built on top of the Yarn PR (#30910)


Review

Key files to check, IN THIS SPECIFIC ORDER:

src/transformers/models/llama/modeling_llama.py 
src/transformers/models/llama/configuration_llama.py
src/transformers/modeling_rope_utils.py

👉 Other relevant files include phi3 (longrope) and recurrentgemma (a few custom changes)


Models that require future changes for standardization

⚠️ Some models don't support cache_positions, and therefore they are not changed as part of this PR (the new class is built with the new pattern in mind). A future PR is needed for these models, where both cache_positions and this new model-agnostic RoPE are added.

Models that were NOT changed but have RoPE:

  • ESM
  • Falcon
  • GPTNeoX
  • GPTNeoXJapanese
  • Idefics
  • Mixtral
  • Persimmon
  • Phi
  • Qwen2
  • Qwen2MoE
  • StableLM
  • Starcoder2

mig-mfreitas and others added 12 commits May 20, 2024 10:21
YaRN (Yet another RoPE extension method) combines the NTK-By-Parts
Interpolation and Attention Scaling methods, improving upon existing
RoPE interpolation methods for longer context window sizes.

Fine-tuned models maintain their original performance across benchmarks
while enabling efficient extrapolation and transfer learning for
quicker convergence, especially in compute-limited environments.

We implement YaRN and Dynamic-YaRN for the following list of models:

 - LLaMA
 - Falcon
 - GPT-NeoX
 - Olmo
 - Persimmon
 - Phi
 - StableLM
 - OpenLLaMA

New unit tests are added to assert YaRN's correct behavior on both
short and long sequence inputs.

For more details, please refer to https://arxiv.org/abs/2309.00071.

Co-authored-by: Miguel Almeida <[email protected]>
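
For reference, here is a minimal sketch of the two YaRN ingredients described above (NTK-by-parts interpolation of the inverse frequencies plus attention scaling), assuming a LLaMA-style rotary embedding. The helper names and the `beta_fast`/`beta_slow` defaults follow the YaRN paper and are illustrative, not necessarily the exact implementation added in these commits.

```python
import math

import torch


def yarn_attention_factor(factor: float) -> float:
    # attention scaling ("temperature") from the YaRN paper: 0.1 * ln(s) + 1.0,
    # applied to cos/sin so attention entropy stays stable at longer contexts
    return 0.1 * math.log(factor) + 1.0


def yarn_inv_freq(dim: int, base: float, factor: float, original_max_position_embeddings: int,
                  beta_fast: float = 32.0, beta_slow: float = 1.0) -> torch.Tensor:
    # NTK-by-parts: interpolate low-frequency dimensions, keep high-frequency ones,
    # and blend linearly in between
    pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs              # original frequencies
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)   # position-interpolated frequencies

    def correction_dim(num_rotations: float) -> float:
        # dimension whose wavelength completes `num_rotations` over the original context
        return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
            2 * math.log(base)
        )

    low = max(math.floor(correction_dim(beta_fast)), 0)
    high = min(math.ceil(correction_dim(beta_slow)), dim - 1)
    ramp = ((torch.arange(dim // 2, dtype=torch.float) - low) / max(high - low, 0.001)).clamp(0, 1)
    extrapolation_mask = 1.0 - ramp  # 1 -> keep original frequency, 0 -> interpolate
    return inv_freq_interpolation * (1 - extrapolation_mask) + inv_freq_extrapolation * extrapolation_mask
```

The resulting inverse frequencies replace the default ones, and the attention factor is multiplied into the cos/sin tensors.
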
Iterate on YaRN implementation for LLaMA and remove diff from remaining
models for increased PR modularity.

This commit includes the following changes:
- Merge 'yarn_rope_scaling' and 'rope_scaling' dictionaries
- Remove unnecessary attributes ('extrapolation_factor' and 'finetuned')
  from YaRN classes
- Inherit 'forward' method in YaRN classes from superclass
- Rename 'yarn' method to 'compute_yarn_scaling'
- Extend YaRN tests with further assertions
- Fix style inconsistencies

Co-authored-by: Miguel Monte e Freitas <[email protected]>
- Comply with the tensor building logic introduced in huggingface#30743
- Add a reference to the optimized attention factor equation
- Remove Dynamic YaRN for a more agile deployment

Co-authored-by: mig-mfreitas <[email protected]>
@ArthurZucker (Collaborator) left a comment


Off to a good start.
We want this to be easily configurable IMO, and with the least amount of checks on our side!

src/transformers/models/llama/modeling_llama.py (review thread: outdated, resolved)
Comment on lines 150 to 151
cos = cos * self.rope_config["attention_factor"]
sin = sin * self.rope_config["attention_factor"]
Collaborator


If this lives in the config rather than in a tensor or buffer, we will have device issues, plus we have less freedom IMO and no idea about the dtype.
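
A possible sketch of the "tensor or buffer" alternative suggested here, under the assumption that the factor is computed once at init time; this is illustrative, not the code the PR ends up with:

```python
import torch
import torch.nn as nn


class AttentionScaledRoPE(nn.Module):
    def __init__(self, rope_config: dict):
        super().__init__()
        # assumed config key; registering it as a non-persistent buffer means it follows
        # the module through .to(device) instead of being re-read from a plain dict
        factor = float(rope_config.get("attention_factor", 1.0))
        self.register_buffer("attention_factor", torch.tensor(factor), persistent=False)

    def apply_scaling(self, cos: torch.Tensor, sin: torch.Tensor):
        # cast to the cos/sin dtype at the point of use, so the dtype is never ambiguous
        factor = self.attention_factor.to(cos.dtype)
        return cos * factor, sin * factor
```

Keeping it as a plain Python float on the module would be an equally simple option, since a scalar multiplication has no device to worry about.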

src/transformers/models/llama/modeling_llama.py (four more review threads: outdated, resolved)
Comment on lines 113 to 121
config = LlamaConfig(**kwargs)
config.rope_theta = base
config.max_position_embeddings = max_position_embeddings
config.head_dim = dim  # this one doesn't actually exist, will only be used in the deprecation transition
if scaling_factor == 1.0 and len(kwargs) == 0:
    config.rope_scaling = None
else:
    config.rope_scaling = {"type": "default", "factor": scaling_factor}
    config.rope_scaling |= kwargs  # may overwrite "type"
Collaborator


That's fairly weird (initializing a config), but it only happens once, so it should be alright.

Member Author


It's the easiest path for the deprecation: in v4.45 we just delete these lines 👼

@gante gante marked this pull request as ready for review July 21, 2024 17:15
@gante (Member, Author) commented on Jul 21, 2024

(all RoPE models with cache_positions upgraded, now fixing CI)

@LysandreJik (Member) left a comment


Just gave a quick look at the API which looks good to me. Very nice and clean changes with the deprecation cycle.

Thanks for iterating on the PR! (Would really like to have @amyeroberts take a look at the PR as well if possible)

src/transformers/models/llama/modeling_llama.py (review thread: outdated, resolved)
@ritwickchaudhry commented on Jul 29, 2024

I'm trying to train the Phi-3-small-128k-instruct model, and loading its configuration leads to an error in the rope_scaling validation function here, because the config has more than 3 hyper-parameters and fails the check.

Will the PR fix this issue? If yes, when can we expect this to merge in main?
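
For context, a rough illustration of the mismatch being described: a longrope-style `rope_scaling` dictionary carries more keys than an older-style validator expects. The dictionary contents and the check below are simplified assumptions, not the actual Phi-3 config or the library's validation code.

```python
# a longrope-style entry has per-dimension factor lists plus the type
rope_scaling = {
    "type": "longrope",
    "short_factor": [1.0] * 48,  # values illustrative
    "long_factor": [4.0] * 48,   # values illustrative
}

# an older-style validator that only expects a small, fixed set of keys rejects it
expected_keys = {"type", "factor"}
unexpected = set(rope_scaling) - expected_keys
if unexpected:
    print(f"validation would fail: unexpected `rope_scaling` keys {sorted(unexpected)}")
```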

@ArthurZucker (Collaborator)

Hmm, what's weird is that this model uses code on the Hub. Anyway, if we broke something we can do a patch, but we need a proper reproducer!

@ritwickchaudhry commented on Jul 30, 2024

> Hmm, what's weird is that this model uses code on the Hub. Anyway, if we broke something we can do a patch, but we need a proper reproducer!

Way to Reproduce:

from transformers import Phi3ForCausalLM

Phi3ForCausalLM.from_pretrained("<path/to/Phi3_small_128k_instruct>")

@ArthurZucker (Collaborator)

That model is "code on the Hub", so it's kind of expected.

@Fazziekey

@gante Hello, can you help me review this PR for fixing NTK scaling?

@gante (Member, Author) commented on Sep 10, 2024

Note: splitting this PR into multiple smaller ones, as the refactor needs extra attention in some models (e.g. cohere's RoPE is not exactly the same as llama's)

Keeping the PR open as a reference, until all models have the new RoPE structure

Successfully merging this pull request may close these issues.

Plans to Integrate LongRoPE into LLaMA?
8 participants