
[Enhancement] Adding dtype to nn layers #679

Open
AakashKumarNain opened this issue Mar 10, 2024 · 9 comments
Labels
feature New feature

Comments

@AakashKumarNain
Contributor

Moving the discussion from #673 to here. As discussed, adding dtype at layer creation is necessary, as it allows loading modules/layers directly with a specific precision (very useful in the context of LLMs), as opposed to initializing with full precision and then mapping to the desired precision.

One thing that we can obviously do without introducing any breaking changes is to add dtype as an attribute to the Module class, defaulting to full precision but overridable by the end user for a specific precision. Just as each layer comes with weight and bias, we would add another attribute dtype inside Module, so that each layer then consists of weight, bias, and dtype.
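
As a purely illustrative sketch (a hand-rolled layer rather than any existing Equinox API), the idea is roughly:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

class LinearWithDtype(eqx.Module):
    # Illustrative only: the parameters are created directly at the requested
    # precision, and the dtype is recorded as a static attribute next to them.
    weight: jax.Array
    bias: jax.Array
    dtype: jnp.dtype = eqx.field(static=True)

    def __init__(self, in_features, out_features, *, key, dtype=jnp.float32):
        self.dtype = dtype
        self.weight = jr.normal(key, (out_features, in_features), dtype=dtype)
        self.bias = jnp.zeros((out_features,), dtype=dtype)

    def __call__(self, x):
        return self.weight @ x + self.bias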

Questions to consider:

  1. Are there any major concerns with this approach? If yes, do we have a simple example to demonstrate that?
  2. Should we create an Initializer abstraction that handles all of this, from dtype to custom kernel initializers, and use an instance of it inside the Module class?
@patrick-kidger
Owner

So one thing I have realised, which speaks in support of this, is that we actually have a dtype argument to BatchNorm already! Although it only handles the dtype of the moving statistics, not (by design) the weight and bias. But we could change that.

One thing I definitely wouldn't do is put dtype as an attribute on Module, as we use them for things other than neural networks -- e.g. differential equation solvers.

So I think the main debate is whether or not to just implement dtype, or to go for a more fully-fledged initialiser abstraction:

  • General initialisers: these are more general, which is good. However this might end up being a kind of API nightmare, where we find ourselves wanting a way to specify how to custom-initialise every single weight of e.g. a MultiheadAttention layer. For example, we might want that when loading some pretrained weights from somewhere. It also means a lot of work implementing an eqx.nn.init module to hold all of the common possibilities...
  • Just a dtype argument: easy to provide as a way to set the dtype of all submodules and weights (probably unusual that we'd ever want mixed dtypes), and easy to understand+implement. However a bit less general.

All of the above is broadly just to make it easier than doing model = eqx.filter_eval_shape(MyModel, **hyperparams); model = eqx.tree_at(..., model, ...), which is admittedly not terribly elegant.
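
For reference, a minimal version of that workaround (assuming a single eqx.nn.Linear and placeholder arrays standing in for a real checkpoint) looks something like:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

# Build the model abstractly: only shapes/dtypes are recorded, no parameter
# memory is allocated.
abstract = eqx.filter_eval_shape(eqx.nn.Linear, 1024, 1024, key=jr.PRNGKey(0))

# Stand-ins for arrays loaded from a checkpoint, already in half precision.
loaded_weight = jnp.zeros((1024, 1024), dtype=jnp.float16)
loaded_bias = jnp.zeros((1024,), dtype=jnp.float16)

# Swap the abstract leaves for the concrete half-precision arrays.
model = eqx.tree_at(
    lambda m: (m.weight, m.bias), abstract, (loaded_weight, loaded_bias)
)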

Right now I'm leaning towards the dtype argument approach, mostly just because it's much much simpler to implement...

@patrick-kidger patrick-kidger added the feature New feature label Mar 10, 2024
@AakashKumarNain
Contributor Author

Just a dtype argument: easy to provide as a way to set the dtype of all submodules and weights (probably unusual that we'd ever want mixed dtypes), and easy to understand+implement. However a bit less general.

Does that mean passing the argument like eqx.nn.Linear(in, out, dtype=dtype)?

@patrick-kidger
Owner

Yup.

@AakashKumarNain
Contributor Author

AakashKumarNain commented Mar 10, 2024

That is perfect IMO! Will save a lot of headache without breaking anything

@patrick-kidger
Owner

Alright, SGTM. I'd be happy to take a PR on this!
(Against the dev branch.)

@AakashKumarNain
Contributor Author

Sounds good. I will send a PR for one layer first to ensure that everything is aligned with what we just discussed. Thank you.

@AakashKumarNain
Contributor Author

This can now be tracked in #680

@Artur-Galstyan
Contributor

Artur-Galstyan commented Mar 13, 2024

So I think the main debate is whether or not to just implement dtype, or to go for a more fully-fledged initialiser abstraction:

  • General initialisers: these are more general, which is good. However this might end up being a kind of API nightmare, where we find ourselves wanting a way to specify how to custom-initialise every single weight of e.g. a MultiheadAttention layer. For example, we might want that when loading some pretrained weights from somewhere. It also means a lot of work implementing an eqx.nn.init module to hold all of the common possibilities...
  • Just a dtype argument: easy to provide as a way to set the dtype of all submodules and weights (probably unusual that we'd ever want mixed dtypes), and easy to understand+implement. However a bit less general.

All of the above is broadly just to make it easier than doing model = eqx.filter_eval_shape(MyModel, **hyperparams); model = eqx.tree_at(..., model, ...), which is admittedly not terribly elegant.

Right now I'm leaning towards the dtype argument approach, mostly just because it's much much simpler to implement...

But isn't all we have to do here to iterate over all leaves of the PyTree and replace them with weights of the given dtype?

Although the dtype argument approach is simple and would definitely work, I wouldn't want to add dtypes everywhere, because it's somewhat reminiscent of boilerplate, and would instead do something like

model = eqx.to_dtype(MyModel(...), jnp.float16)

Because that's much more elegant and useful, and I don't have to think about dtype precision when writing out the model.

I'm pretty sure you can easily do that with a clever combination of jax.tree.leaves, jax.tree.map and eqx.filter(..., eqx.is_array). I wrote something in #602 which manipulates the dtypes of the PyTree. We'd just have to simplify that a bit, adjust it to our needs, and then we should be done, right? Or am I missing something here?
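
For concreteness, a helper along those lines (to_dtype here is a hypothetical name, not an existing Equinox function) could be as simple as:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

def to_dtype(model, dtype):
    # Cast every floating-point array leaf of the pytree to the given dtype,
    # leaving non-array leaves (activation functions, static fields, ...) alone.
    return jax.tree_util.tree_map(
        lambda leaf: leaf.astype(dtype) if eqx.is_inexact_array(leaf) else leaf,
        model,
    )

# Usage: build the model in full precision first, then cast it down.
model = to_dtype(eqx.nn.MLP(4, 4, 32, 2, key=jr.PRNGKey(0)), jnp.float16)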

@AakashKumarNain
Contributor Author

Or am I missing something here?

If you have something like a TransformerBlock that is repeated n times, then initializing it in full precision and converting to the required precision afterwards would result in OOM for most LLMs. You also wouldn't be able to use eqx.filter_vmap(...) to create a stacked layer, and would end up using a loop where each block is converted to the desired precision as soon as it is created.
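
For example, a stacked set of blocks is usually built with something like the following, where the dtype= keyword is the proposed API rather than anything Linear accepts today:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

n_layers = 32

def make_block(key):
    # With a dtype argument, each block is created directly in half precision,
    # so the full-precision parameters never exist in memory.
    return eqx.nn.Linear(4096, 4096, key=key, dtype=jnp.bfloat16)

stacked = eqx.filter_vmap(make_block)(jr.split(jr.PRNGKey(0), n_layers))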
