[Enhancement] Adding `dtype` to nn layers #679

Comments

So one thing I have realised, that speaks in support of this, is that we actually have a […]. One thing I definitely wouldn't do is put […]. So I think the main debate is whether or not to just implement […].

All of the above is broadly just to make it easier than doing […]. Right now I'm leaning towards the […].

Does that mean passing the argument like […]?

Yup.

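The exact snippet elided from the question above is not recoverable from this page. Purely for illustration, the constructor-level usage being agreed on would presumably look something like the sketch below; treat the `dtype` keyword on `eqx.nn.Linear` as the proposed addition under discussion here, not as a guaranteed part of any particular Equinox release.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)

# Proposed usage: the layer's parameters are created directly in bfloat16;
# omitting dtype keeps the current full-precision (float32) default.
linear = eqx.nn.Linear(4, 8, key=key, dtype=jnp.bfloat16)
```
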
That is perfect IMO! Will save a lot of headache without breaking anything.

Alright, SGTM. I'd be happy to take a PR on this!

Sounds good. I will send a PR for one layer first to ensure that everything is aligned with what we just discussed. Thank you!

This can now be tracked in #680.

But isn't all we have to do here to iterate over all leaves of the PyTree and replace them with weights of the given dtype? Although the […]

`model = eqx.to_dtype(MyModel(...), jnp.float16)`

Because that's much more elegant and useful, and I don't have to think about dtype precision when writing out the model. I'm pretty sure you can easily do that with a clever combination of […].

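Note that `eqx.to_dtype` as written above is the commenter's proposed API, not an existing Equinox function. A minimal sketch of how such a helper could be put together, assuming the desired behaviour is to cast every inexact (floating-point) array leaf and leave all other leaves untouched:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

def to_dtype(model, dtype):
    """Hypothetical helper (not part of Equinox): cast every floating-point
    array leaf of a model PyTree to `dtype`, leaving all other leaves
    (integer arrays, activation functions, etc.) unchanged."""
    def cast(leaf):
        return leaf.astype(dtype) if eqx.is_inexact_array(leaf) else leaf
    return jax.tree_util.tree_map(cast, model)

# Example: build the model in float32, then cast its parameters to float16.
model = to_dtype(eqx.nn.Linear(4, 8, key=jax.random.PRNGKey(0)), jnp.float16)
```

One limitation of a blanket cast like this is that every floating-point leaf is treated the same, and the parameters are still allocated in full precision before being converted, whereas a construction-time `dtype` lets each layer create its parameters at the desired precision from the start.
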
If you have something like a […]

Original issue description

Moving the discussion from #673 to here. As discussed, adding `dtype` for the creation of layers is necessary, as it allows loading modules/layers directly at a specific precision (very useful in the context of LLMs), as opposed to initializing at full precision and then mapping to the desired precision.

One thing that we can obviously do without introducing any breaking changes is to add `dtype` as an attribute to the `Module` class that defaults to full precision but can be overridden by the end user for a specific precision. Just as each layer comes with `weight` and `bias`, we would add another attribute `dtype` inside `Module`, so that each layer then consists of `weight`, `bias`, and `dtype`.

Questions to consider:

- An `Initializer` abstraction that handles all of this (from dtype to custom kernel initializers), with an instance of it used inside the `Module` class?
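
For concreteness, a minimal sketch of the kind of layer this proposal describes is below. `MyLinear`, its initialization scheme, and the way `dtype` is stored and used are illustrative assumptions made for this example, not Equinox's actual implementation.

```python
from typing import Any

import jax
import jax.numpy as jnp
import equinox as eqx

class MyLinear(eqx.Module):
    """Illustrative layer (not eqx.nn.Linear itself): parameters are created
    directly in the requested dtype rather than cast after initialization."""
    weight: jax.Array
    bias: jax.Array
    dtype: Any = eqx.field(static=True)

    def __init__(self, in_features, out_features, *, key, dtype=jnp.float32):
        wkey, bkey = jax.random.split(key)
        lim = 1.0 / in_features**0.5
        self.weight = jax.random.uniform(
            wkey, (out_features, in_features), minval=-lim, maxval=lim, dtype=dtype
        )
        self.bias = jax.random.uniform(
            bkey, (out_features,), minval=-lim, maxval=lim, dtype=dtype
        )
        self.dtype = dtype

    def __call__(self, x):
        return self.weight @ x + self.bias

# Half-precision parameters from the moment of creation:
layer = MyLinear(4, 8, key=jax.random.PRNGKey(0), dtype=jnp.float16)
```

Constructing the layer with `dtype=jnp.float16` allocates the parameters in half precision from the start, rather than creating them in float32 and casting afterwards.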