[Proposal] Allow tied embeddings #671

Open
neelnanda-io opened this issue Jul 12, 2024 · 1 comment
Labels
complexity-moderate (Moderately complicated issues for people who have intermediate experience with the code), enhancement (New feature or request)

Comments

@neelnanda-io
Collaborator

Proposal

TransformerLens assumes all models have untied embeddings (i.e. W_U ≠ W_E.T). This is a good assumption in general, and it has to hold if LayerNorm is folded, because folding rescales W_U. But it costs more memory, since W_U is stored as a separate copy.
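To see why folding forces the embeddings apart, here is a minimal NumPy sketch. The toy sizes, random weights, and variable names are assumptions for illustration, not TransformerLens code. Folding the final LayerNorm's scale `w` into the unembedding multiplies each row of `W_U` (one row per residual-stream dimension) by `w`, so the folded matrix can no longer be a view of `W_E.T`.

```python
import numpy as np

# Toy sizes and random weights (assumptions); real models are far larger.
d_vocab, d_model = 10, 4
rng = np.random.default_rng(0)

W_E = rng.normal(size=(d_vocab, d_model))  # embedding: token -> residual stream
W_U = W_E.T                                # tied unembedding: a view, no extra memory

# Stand-in for the final norm's learned scale.
# Folding it into W_U rescales each row of W_U: logits = normalized @ (w[:, None] * W_U).
w = rng.normal(size=d_model)
W_U_folded = w[:, None] * W_U              # a new array of shape (d_model, d_vocab)

print(np.shares_memory(W_U, W_E))          # the tied view shares storage with W_E
print(np.allclose(W_U_folded, W_E.T))      # differs from W_E.T unless w is all ones
print(np.shares_memory(W_U_folded, W_E))   # folding allocates a separate matrix
```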

This is particularly bad for Gemma models, which have tied embeddings and a very large vocab size: e.g. 25% of Gemma 2 2.6B's params are in W_E, and 10% of Gemma 2 9B's. I think it would be great to load tied models with tied embeddings by default (so W_U.data = W_E.data.T), plus a helper function that clones the matrix to untie them if need be. This would involve adding a tied_embeddings field to the Config that defaults to False, can be set to True for select models like GPT-2, Gemma, and Gemma 2, and gets set back to False if fold_layernorm is run.
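The proposed design can be sketched with a toy model. Everything here is hypothetical: `ToyConfig`, `ToyModel`, `untie_embeddings`, and this `fold_layernorm` are illustrative names, not the TransformerLens API, and NumPy arrays stand in for PyTorch parameters. When `tied_embeddings` is True, `W_U` is only a transposed view of `W_E`; the untie helper clones it; and folding calls the helper first, which resets the flag to False.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class ToyConfig:
    tied_embeddings: bool = False  # proposed field; untied by default


class ToyModel:
    """Hypothetical stand-in for a model's embedding and unembedding."""

    def __init__(self, cfg: ToyConfig, W_E: np.ndarray):
        self.cfg = cfg
        self.W_E = W_E  # shape (d_vocab, d_model)
        # Untied models keep their own copy of W_U, as TransformerLens does today.
        self._W_U: Optional[np.ndarray] = None if cfg.tied_embeddings else W_E.T.copy()

    @property
    def W_U(self) -> np.ndarray:
        # Tied: a transposed view of W_E, so no second matrix is allocated.
        return self.W_E.T if self.cfg.tied_embeddings else self._W_U

    def untie_embeddings(self) -> None:
        """Hypothetical helper: clone W_E.T into an independent W_U."""
        if self.cfg.tied_embeddings:
            self._W_U = self.W_E.T.copy()
            self.cfg.tied_embeddings = False

    def fold_layernorm(self, ln_scale: np.ndarray) -> None:
        """Fold the final norm's scale, shape (d_model,), into W_U; forces untying."""
        self.untie_embeddings()
        self._W_U = ln_scale[:, None] * self._W_U


model = ToyModel(ToyConfig(tied_embeddings=True), np.arange(24, dtype=float).reshape(8, 3))
print(np.shares_memory(model.W_U, model.W_E))  # tied: one matrix in memory
model.fold_layernorm(np.array([2.0, 1.0, 0.5]))
print(model.cfg.tied_embeddings, np.shares_memory(model.W_U, model.W_E))
```

Making `W_U` a read-only view while tied means any update to `W_E` is automatically reflected in the unembedding; the cost is that anything that edits `W_U` in place must untie first.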

I'd love people to be able to work with the Gemma 2 models with a bunch of SAEs in memory, so memory efficiency is important (and folding LayerNorm isn't that important).
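For a rough sense of scale, the arithmetic below computes how much memory the extra, untied copy of the embedding matrix costs. The Gemma 2 dimensions (vocabulary 256,000; d_model 2,304 for 2B and 3,584 for 9B) are assumptions taken from the public model configs rather than from this thread, and `embedding_bytes` is a hypothetical helper name.

```python
def embedding_bytes(d_vocab: int, d_model: int, bytes_per_param: int) -> int:
    """Bytes needed to store one d_vocab x d_model matrix."""
    return d_vocab * d_model * bytes_per_param


# Assumed dimensions (d_vocab, d_model); check against the model you actually load.
GEMMA2_DIMS = {
    "gemma-2-2b": (256_000, 2_304),
    "gemma-2-9b": (256_000, 3_584),
}

for name, (d_vocab, d_model) in GEMMA2_DIMS.items():
    for dtype, nbytes in (("float32", 4), ("bfloat16", 2)):
        # Untied embeddings store a second matrix of this size as W_U.
        saved = embedding_bytes(d_vocab, d_model, nbytes)
        print(f"{name} {dtype}: tying saves {saved / 1e9:.2f} GB")
```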

@bryce13950
Collaborator

I did a quick experiment in the architecture-specific weight conversions to see if it was sufficient for tying the weights when needed: https://github.com/TransformerLensOrg/TransformerLens/tree/experiment-gemma-weight-tying. This still needs to be tested. I am not sure if what I did there is sufficient to solve the issue, and this is the sort of change I am a bit wary of, since it could ripple out if done incorrectly. If you have time to mess with my branch, that would be super helpful. My time is pretty full for the next couple of weeks while I wrap some other things up, but once it frees up I would be happy to experiment with this a bit.
If it seems to work well, I will probably revise the weight conversions to share some code, so that system-wide changes like this can be made in a central place without too many issues.
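One way the "central place" idea could look is a shared post-processing step that every architecture-specific converter calls last. This is a hypothetical sketch, not the code in the experiment branch: `tie_unembed` and `convert_toy_gemma_weights` are invented names, NumPy arrays stand in for tensors, and the state-dict keys (`embed.W_E`, `unembed.W_U`, `unembed.b_U`, `model.embed_tokens.weight`) are assumed for illustration.

```python
import numpy as np


def tie_unembed(state_dict: dict, tied: bool) -> dict:
    """Hypothetical shared step for converters of models whose checkpoints tie W_E/W_U.

    If tied, "unembed.W_U" becomes a transposed view of "embed.W_E", so only one
    matrix is stored; otherwise it gets an independent copy.
    """
    W_E = state_dict["embed.W_E"]  # shape (d_vocab, d_model)
    state_dict["unembed.W_U"] = W_E.T if tied else W_E.T.copy()
    return state_dict


def convert_toy_gemma_weights(hf_weights: dict, tied: bool) -> dict:
    """Hypothetical, heavily simplified architecture-specific converter."""
    W_E = hf_weights["model.embed_tokens.weight"]
    state_dict = {
        "embed.W_E": W_E,
        "unembed.b_U": np.zeros(W_E.shape[0]),  # no unembedding bias in this toy
    }
    # ...per-layer weights would be converted here...
    return tie_unembed(state_dict, tied)


sd = convert_toy_gemma_weights({"model.embed_tokens.weight": np.ones((6, 3))}, tied=True)
print(sd["unembed.W_U"].shape, np.shares_memory(sd["unembed.W_U"], sd["embed.W_E"]))
```

Keeping the tie/untie decision in one helper means a later change (for example, resetting the flag when folding) only has to be made once rather than in every converter.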

bryce13950 removed their assignment Jul 12, 2024
bryce13950 added the enhancement and complexity-moderate labels Jul 12, 2024
Projects
None yet
Development

No branches or pull requests

2 participants