Simple `dtype` argument addition #680
Conversation
@@ -23,6 +24,7 @@ def __init__(
     in_features: Union[int, Literal["scalar"]],
     out_features: Union[int, Literal["scalar"]],
     use_bias: bool = True,
+    dtype=None,
The type annotation for `dtype` is missing here. I don't have my laptop for a couple of days, but I think the type is `jax.numpy.dtype`, and it should be something like `dtype: Optional[jax.numpy.dtype] = None`.
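To illustrate the suggested signature, here is a minimal sketch of the default-dtype pattern under discussion. `make_linear_params` is a hypothetical helper, not equinox's actual implementation, and it uses NumPy dtypes for illustration where equinox would use `jax.numpy` ones:

```python
from typing import Optional

import numpy as np

def make_linear_params(
    in_features: int,
    out_features: int,
    use_bias: bool = True,
    dtype: Optional[np.dtype] = None,  # the annotation suggested in the review
):
    # Fall back to a default dtype when none is given, mirroring the
    # `dtype=None` default in the diff above.
    dtype = np.dtype(dtype) if dtype is not None else np.dtype(np.float32)
    # The chosen dtype applies to every trainable parameter: both the
    # weight matrix and (if enabled) the bias vector.
    weight = np.zeros((out_features, in_features), dtype=dtype)
    bias = np.zeros((out_features,), dtype=dtype) if use_bias else None
    return weight, bias
```

Note that both `weight` and `bias` share the one `dtype`, which is the behavior the later review comment asks the docstring to spell out.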
equinox/nn/_linear.py
@@ -33,6 +35,8 @@ def __init__(
     - `out_features`: The output size. The output from the layer will be a vector
       of shape `(out_features,)`.
     - `use_bias`: Whether to add on a bias as well.
+    - `dtype`: The dtype to use. Defaults to either `jax.numpy.float32` or
IMO this should also tell the user that it applies to all trainable parameters (not just the weight but also the bias).
I have updated this
I think I like this! I wouldn't worry about the type annotation, so the only thing I'd tweak is the docstring a little bit. I don't think we need to make `dtype` an attribute, and I wouldn't try to raise a warning in `__call__`.
I have updated the docstring to reflect the effect of this argument in a better way.
SGTM
Alright, I think I'm happy with this pattern! Since we're settled on that, do I understand correctly that you're now aiming to repeat this across other parts of `equinox.nn`?
Yes, I will send a separate PR for that. You can merge this one so that I am also unblocked for my Mistral based experiments. |
Okay, great! |
Thank you for the review and the help!
Simple addition of the `dtype` argument, along with a very simple test (I will increase the code coverage later).
A few thoughts: should `__call__` check the input dtype? For example, what if the user passed a `float32` array to a layer initialized with `float16`? Ideally we should warn the user to avoid silent bugs.
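As a concrete illustration of the mismatch being discussed (a warning the maintainer ultimately decided against adding), here is a minimal sketch using NumPy for illustration; `checked_call` is a hypothetical function, not part of equinox:

```python
import warnings

import numpy as np

def checked_call(weight: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Apply a linear layer's weight, warning on a dtype mismatch."""
    if x.dtype != weight.dtype:
        # Without this warning, the mismatch is silent: the result
        # simply follows type-promotion rules, which may hide a bug.
        warnings.warn(
            f"Input dtype {x.dtype} does not match layer dtype "
            f"{weight.dtype}; the output will be promoted."
        )
    return weight @ x
```

For instance, passing a `float32` input to a `float16` layer triggers the warning, and the output is promoted to `float32` rather than staying in the layer's dtype.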