
Falcon: Add RoPE scaling #25878

Merged · gante merged 2 commits into huggingface:main from falcon_rope_scaling on Sep 1, 2023

Conversation

@gante (Member) commented on Aug 30, 2023:

What does this PR do?

In the same spirit as #24653, this PR adds RoPE scaling to Falcon. It also borrows a few changes from #25740 to allow for CodeLlama-style scaling.

In addition to the changes above, it adds max_position_embeddings to the config attributes, which is needed for one of the scaling strategies.

Python script to validate these changes: https://pastebin.com/SJmUpDU1
Before this PR 👉 outputs gibberish
After this PR 👉 recognizes that the very long prompt is about Llama 2
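
For context, the feature is exposed through the same `rope_scaling` config dict convention as Llama (#24653). A minimal sketch with illustrative values, assuming a version of transformers that includes this PR:

```python
from transformers import FalconConfig

# Illustrative values: linear RoPE scaling by 2x on top of the original context length.
config = FalconConfig(
    rope_scaling={"type": "linear", "factor": 2.0},
    max_position_embeddings=2048,  # now a config attribute; needed by the dynamic strategy
)
print(config.rope_scaling, config.rope_theta)
```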

```python
if self.rope_scaling is None:
    return

if self.rotary:
```
@gante (Member Author):

This function is copy/paste from #24653, except for this if block.
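
For orientation, a sketch of what such a validation function looks like, modelled on the Llama check from #24653; the Falcon-specific guard and the error messages are approximations, not the merged diff:

```python
def _rope_scaling_validation(self):
    """Sketch of the `rope_scaling` validation (modelled on the Llama version; approximate)."""
    if self.rope_scaling is None:
        return

    if not self.rotary:
        # Falcon-specific: RoPE scaling only applies when rotary embeddings are used (no alibi).
        raise ValueError("`rope_scaling` can only be used with rotary position embeddings.")

    if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
        raise ValueError("`rope_scaling` must be a dictionary with two fields, `type` and `factor`.")
    rope_scaling_type = self.rope_scaling.get("type", None)
    rope_scaling_factor = self.rope_scaling.get("factor", None)
    if rope_scaling_type not in ("linear", "dynamic"):
        raise ValueError("`rope_scaling`'s `type` field must be one of ['linear', 'dynamic'].")
    if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
        raise ValueError("`rope_scaling`'s `factor` field must be a float > 1.")
```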

```python
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
```
@gante (Member Author):

Separates the logic for the creation of self.cos_cached and self.sin_cached, since these are the only bits the scaling variants need to overwrite.
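
Roughly what that separation looks like; a simplified sketch whose class signatures follow the Llama pattern, not the exact merged implementation:

```python
import torch


class FalconRotaryEmbedding(torch.nn.Module):
    # Sketch of the refactor described above: only the cos/sin cache construction differs
    # between scaling variants, so it lives in _set_cos_sin_cache and subclasses override it.
    def __init__(self, head_dim, base=10000, max_position_embeddings=2048):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.seq_len_cached = -1  # the forward pass rebuilds the cache when seq_len exceeds this

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)


class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
    def __init__(self, head_dim, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
        super().__init__(head_dim, base, max_position_embeddings)
        self.scaling_factor = scaling_factor

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor  # the only difference from the parent class
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)


rope = FalconLinearScalingRotaryEmbedding(64, scaling_factor=2.0)
rope._set_cos_sin_cache(seq_len=4096, device="cpu", dtype=torch.bfloat16)
print(rope.cos_cached.shape)  # torch.Size([4096, 64])
```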

A reviewer (Member):

Makes sense to me!

@HuggingFaceDocBuilderDev commented on Aug 30, 2023:

The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1 (Member) left a review:

LGTM, the details seem correct from my read-through, and it shouldn't have any backward-compatibility implications. Still, please check with the authors before merging it!

```python
t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
```
A reviewer (Member):

Makes sense to me!

src/transformers/models/falcon/configuration_falcon.py (outdated review thread, resolved)
@amyeroberts (Collaborator) left a review:
Thanks for adding this! 🤗

Just a few implementation questions / requests.

src/transformers/models/falcon/configuration_falcon.py (outdated review thread, resolved)
Comment on lines +78 to +79:

```
rope_theta (`float`, *optional*, defaults to 10000.0):
    The base period of the RoPE embeddings.
```
@amyeroberts (Collaborator):

Why isn't this part of the rope scaling dict?

@gante (Member Author), Aug 31, 2023:

Technically this isn't a RoPE scaling parameter; it is a constant used to compute the embeddings. As such, I agree with the original decision, in the CodeLlama PR, to keep it separate :)

Coincidentally, it can also be used for scaling (i.e. by increasing it and then fine-tuning the resulting model).
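
To make the distinction concrete, a small sketch of how `rope_theta` enters the (unscaled) embedding computation:

```python
import torch

def rope_inverse_frequencies(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # inv_freq[i] = 1 / rope_theta ** (2i / head_dim): rope_theta sets the base period of the
    # sinusoids and is needed even when no scaling is applied. Increasing it (and fine-tuning)
    # stretches the periods, which is why it can *also* be used as a scaling knob.
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

print(rope_inverse_frequencies(64, rope_theta=10000.0)[:4])
print(rope_inverse_frequencies(64, rope_theta=1_000_000.0)[:4])  # longer periods, lower frequencies
```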

src/transformers/models/falcon/configuration_falcon.py (two outdated review threads, resolved)
```python
if self.config.rope_scaling is None:
    rotary_emb = FalconRotaryEmbedding(
        self.head_dim,
        base=self.config.rope_theta,
```
@amyeroberts (Collaborator):

I realise this is copied from elsewhere and not introduced in this PR, but it's an indication to me that this parameter is misnamed: if no RoPE scaling is happening, then a RoPE parameter shouldn't be used here.

@gante (Member Author):

(see the comment above; I believe the two questions are related)
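
For the surrounding context, the branch in question sits in a dispatch along these lines. This builds on the class sketch above; `FalconDynamicNTKScalingRotaryEmbedding` and the exact signatures are assumed to follow the Llama pattern, not quoted from the merged code:

```python
def build_rotary_embedding(config, head_dim):
    # Sketch of the dispatch: rope_theta is always passed as `base`, with or without scaling.
    if config.rope_scaling is None:
        return FalconRotaryEmbedding(
            head_dim,
            base=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
        )
    scaling_type = config.rope_scaling["type"]
    scaling_factor = config.rope_scaling["factor"]
    if scaling_type == "linear":
        return FalconLinearScalingRotaryEmbedding(
            head_dim,
            base=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            scaling_factor=scaling_factor,
        )
    if scaling_type == "dynamic":
        return FalconDynamicNTKScalingRotaryEmbedding(
            head_dim,
            base=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            scaling_factor=scaling_factor,
        )
    raise ValueError(f"Unknown RoPE scaling type: {scaling_type}")
```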

```python
def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
```
@amyeroberts (Collaborator):

Could you rewrite this with classic matrix operations? Unfortunately, einsum creates issues with traced models at the moment :(
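
A sketch of the kind of rewrite being requested here; an outer product or an explicit matmul gives the same result and traces cleanly:

```python
import torch

t = torch.arange(8, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))

freqs_einsum = torch.einsum("i,j->ij", t, inv_freq)  # current form
freqs_outer = torch.outer(t, inv_freq)               # classic outer product
freqs_matmul = t[:, None] @ inv_freq[None, :]        # equivalent explicit matmul

assert torch.allclose(freqs_einsum, freqs_outer)
assert torch.allclose(freqs_einsum, freqs_matmul)
```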


```python
def _set_cos_sin_cache(self, seq_len, device, dtype):
    self.seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
```
@amyeroberts (Collaborator):

Same here re the parameter `t`.

```python
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
```
@amyeroberts (Collaborator):

And here re einsum.

```python
base = self.base * (
    (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
```
@amyeroberts (Collaborator):

Do we want to conditionally call .float() here too?

@gante (Member Author):

Yes, we want to keep these auxiliary calculations in fp32 and then cast down the final results (the sin and cos caches) if needed :)
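
A simplified sketch of that policy for the dynamic-NTK cache, in free-function form and using an outer product instead of the einsum; not the merged code:

```python
import torch

def dynamic_ntk_cos_sin_cache(seq_len, head_dim, base, scaling_factor, max_position_embeddings, dtype):
    # Recompute the RoPE base when seq_len exceeds the training length (dynamic NTK scaling),
    # keeping every intermediate in float32 and casting only the final caches to `dtype`.
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (head_dim / (head_dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)  # cast down only at the very end

cos, sin = dynamic_ntk_cos_sin_cache(4096, 64, 10000.0, 2.0, 2048, torch.bfloat16)
print(cos.dtype, cos.shape)  # torch.bfloat16 torch.Size([4096, 64])
```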

Comment on lines +161 to +162:

```python
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
```
@amyeroberts (Collaborator):

Same comments about `t` and einsum.

@gante (Member Author) commented on Sep 1, 2023:

Re the `t` and einsum review comments 👉 as discussed offline, these will be fixed in a follow-up PR.

I've also tested this PR against the case I wanted to cover, and it works correctly both with and without RoPE scaling!

Merging :)
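
For readers who want to reproduce the check without the linked script, a minimal sketch of the comparison; the checkpoint name and prompt are placeholders, and this is not the pastebin script itself:

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

checkpoint = "tiiuae/falcon-7b"   # placeholder checkpoint
long_prompt = "..."               # a prompt well beyond the training context length

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(long_prompt, return_tensors="pt")

for rope_scaling in (None, {"type": "dynamic", "factor": 2.0}):
    config = AutoConfig.from_pretrained(checkpoint, rope_scaling=rope_scaling)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)
    outputs = model.generate(**inputs, max_new_tokens=50)
    print(rope_scaling, tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:]))
```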

@gante merged commit 53e2fd7 into huggingface:main on Sep 1, 2023
@gante deleted the falcon_rope_scaling branch on Sep 1, 2023, 11:05
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023