Falcon: Add RoPE scaling #25878
Conversation
if self.rope_scaling is None:
    return

if self.rotary:
This function is a copy/paste from #24653, except for this if block.
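For readers without #24653 open, that validation helper has roughly the following shape. This is a hedged sketch following the {"type", "factor"} rope_scaling convention from that PR, not a verbatim copy of either PR:

```python
# Rough sketch of a rope_scaling validation helper on the config class,
# following the {"type": ..., "factor": ...} convention from #24653.
# Not verbatim code; the Falcon-specific `if` block mentioned above is the
# part that differs from the Llama version.
def _rope_scaling_validation(self):
    if self.rope_scaling is None:
        return

    if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
        raise ValueError(
            f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
        )
    rope_scaling_type = self.rope_scaling.get("type", None)
    rope_scaling_factor = self.rope_scaling.get("factor", None)
    if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
        raise ValueError(
            f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
        )
    if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
        raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
```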
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
Separates the logic for the creation of self.cos_cached and self.sin_cached, since these are the only bits the other scaling variants need to overwrite.
Makes sense to me!
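As an illustration of the refactor being discussed, here is a minimal sketch (not the exact PR code) of a base rotary embedding whose scaling subclass only overrides _set_cos_sin_cache:

```python
import torch

# Hedged sketch: the base class builds the cos/sin caches in
# _set_cos_sin_cache, and a scaling variant only has to override that method.
class RotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, head_dim, base=10000.0, max_position_embeddings=2048):
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings, inv_freq.device, torch.float32)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)


class LinearScalingRotaryEmbeddingSketch(RotaryEmbeddingSketch):
    def __init__(self, head_dim, base=10000.0, max_position_embeddings=2048, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(head_dim, base, max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor  # the only difference from the base class
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)
```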
The documentation is not available anymore as the PR was closed or merged.
LGTM, the details seemed correct when I went through it, and it shouldn't have any backward compatibility implications. Still, please check with the authors before merging it!
Thanks for adding this! 🤗
Just a few implementation questions / requests.
rope_theta (`float`, *optional*, defaults to 10000.0):
    The base period of the RoPE embeddings.
Why isn't this part of the rope scaling dict?
Technically this isn't a RoPE scaling parameter; it is a constant used to compute the embeddings. As such, I agree with the original decision in the CodeLlama PR to keep it separate :)
Coincidentally, it can also be used for scaling (i.e. by increasing it and then fine-tuning the resulting model).
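A minimal sketch of how the base period enters the computation, assuming the standard RoPE inverse-frequency formula; rotary_inv_freq is a hypothetical helper, not transformers code:

```python
import torch

# rope_theta is the base of the geometric progression of rotary frequencies.
# Increasing it makes every frequency smaller (longer wavelengths), which is
# why raising it and fine-tuning can also act as a form of context extension.
def rotary_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

print(rotary_inv_freq(64, rope_theta=10000.0)[:4])
print(rotary_inv_freq(64, rope_theta=1000000.0)[:4])  # larger base -> lower frequencies
```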
if self.config.rope_scaling is None:
    rotary_emb = FalconRotaryEmbedding(
        self.head_dim,
        base=self.config.rope_theta,
I realise this is copied from elsewhere and not from this PR, but this is an indication to me that the parameter is misnamed: if no RoPE scaling is happening, then a RoPE scaling parameter shouldn't be used here.
(see comment above, I believe the two questions are related)
def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Could you re-write this with classic matrix operations? Unfortunately einsum creates issues when using traced models atm :(
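For reference, the "i,j->ij" einsum is just an outer product, so it can be swapped for plain matrix operations. A small sketch of the equivalence, using standalone toy values rather than the PR code:

```python
import torch

# The three expressions below produce identical (seq_len, dim/2) angle tables;
# torch.outer and matmul trace cleanly where einsum may not.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
t = torch.arange(2048, dtype=inv_freq.dtype)

freqs_einsum = torch.einsum("i,j->ij", t, inv_freq)
freqs_outer = torch.outer(t, inv_freq)          # equivalent, trace-friendly
freqs_matmul = t[:, None] @ inv_freq[None, :]   # same thing with a matmul

assert torch.allclose(freqs_einsum, freqs_outer)
assert torch.allclose(freqs_einsum, freqs_matmul)
```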
def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
Same here re the parameter 't'
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
And here re einsum
base = self.base * (
    (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
Do we want to conditionally call .float here too?
Yes, we want to keep these auxiliary calculations in fp32 and then cast down the final results (sin_cache and cos_cache) if needed :)
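A small sketch of that recipe, keeping the angle computation in fp32 and casting only the final caches; build_cos_sin_cache is a hypothetical standalone helper, not the transformers implementation:

```python
import torch

# Do the frequency/angle math in fp32 for numerical stability, then cast
# only the cached sin/cos tables down to the model dtype.
def build_cos_sin_cache(seq_len, head_dim, base=10000.0, dtype=torch.bfloat16, device="cpu"):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)  # keep fp32 here
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    # cast down only at the very end
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos_cache, sin_cache = build_cos_sin_cache(seq_len=2048, head_dim=64)
print(cos_cache.dtype, sin_cache.dtype)  # torch.bfloat16 torch.bfloat16
```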
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
Same comments about 't' and einsum
I've also tested this PR against the thing I wanted to test; it works correctly with and without RoPE scaling! Merging :)
What does this PR do?
In the same spirit as #24653, this PR adds RoPE scaling to Falcon. It also borrows a few changes from #25740, to allow for codellama-style scaling.

In addition to the changes above, it adds max_position_embeddings to the config attributes, which is needed for one of the scaling strategies.

Python script to validate these changes: https://pastebin.com/SJmUpDU1
Before this PR 👉 outputs gibberish
After this PR 👉 recognizes that the super large prompt is about llama 2
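For illustration, a hedged sketch of how the new options could be exercised from user code, assuming the same rope_scaling dict interface as #24653; the model id and scaling factor below are placeholders:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical usage sketch; model id and factor are placeholders.
config = AutoConfig.from_pretrained(
    "tiiuae/falcon-7b",
    rope_scaling={"type": "dynamic", "factor": 2.0},  # or {"type": "linear", "factor": 2.0}
    max_position_embeddings=2048,  # used by the dynamic scaling strategy
)

model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b", config=config)
```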