Simple dtype argument addition #680

Merged on Mar 16, 2024 (4 commits)


AakashKumarNain
Contributor

Simple addition of the dtype argument along with a very simple test (will increase the code coverage later).

A few thoughts:

  1. If we don't make `dtype` an attribute, we won't be able to do certain "nice to have" things in `__call__`. For example, what if the user passes a float32 array to a layer initialized with float16? Ideally we should warn the user, to avoid silent bugs.
  2. I am not sure what the right annotation for the `dtype` argument should be.
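
To make the first point concrete, here is a minimal sketch of what such a check might look like. It is not Equinox code: `DtypeCheckedLinear` is a hypothetical plain-Python class invented for illustration, and storing `dtype` as an attribute is exactly the design choice being asked about.

```python
import warnings

import jax
import jax.numpy as jnp


class DtypeCheckedLinear:
    """Hypothetical layer that remembers its dtype and warns on a mismatch."""

    def __init__(self, in_features, out_features, dtype, *, key):
        # Stored as an attribute so that __call__ can compare against it.
        self.dtype = jnp.dtype(dtype)
        self.weight = jax.random.normal(
            key, (out_features, in_features), dtype=self.dtype
        )

    def __call__(self, x):
        if x.dtype != self.dtype:
            warnings.warn(
                f"input dtype {x.dtype} differs from layer dtype {self.dtype}"
            )
        return self.weight @ x


layer = DtypeCheckedLinear(3, 2, jnp.float16, key=jax.random.PRNGKey(0))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # float32 input into a float16 layer triggers the warning.
    out = layer(jnp.ones(3, dtype=jnp.float32))
```

Without the stored attribute, `__call__` could only compare against `self.weight.dtype`, which is the trade-off raised above.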

@@ -23,6 +24,7 @@ def __init__(
in_features: Union[int, Literal["scalar"]],
out_features: Union[int, Literal["scalar"]],
use_bias: bool = True,
dtype=None,
Contributor

@Artur-Galstyan Artur-Galstyan Mar 14, 2024

The type annotation for dtype is missing here. I don't have my laptop for a couple of days, but I think the type of dtype is jax.numpy.dtype, so it should be something like `dtype: Optional[jax.numpy.dtype] = None`.
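
As a sketch of one possible annotation: instead of `jax.numpy.dtype`, the example below uses `jax.typing.DTypeLike`, which, as far as I know, is intended to cover scalar types such as `jnp.float16` and string names as well as dtype instances. The helper `make_weight` is hypothetical and only exists to show the annotation on a signature; this is an assumption about what could work, not what the PR adopted.

```python
from typing import Optional

import jax.numpy as jnp
from jax.typing import DTypeLike


def make_weight(shape: tuple[int, ...], dtype: Optional[DTypeLike] = None):
    # Hypothetical helper. With dtype=None, jnp.zeros falls back to
    # JAX's default floating-point dtype.
    return jnp.zeros(shape, dtype=dtype)


w16 = make_weight((2, 3), dtype=jnp.float16)  # scalar type object
w_str = make_weight((2, 3), dtype="float16")  # string name
```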

@@ -33,6 +35,8 @@ def __init__(
- `out_features`: The output size. The output from the layer will be a vector
of shape `(out_features,)`.
- `use_bias`: Whether to add on a bias as well.
- `dtype`: The dtype to use. Defaults to either `jax.numpy.float32` or
Contributor

This should also tell the user that it applies to all trainable parameters (not just the weights but also the bias) IMO

Contributor Author

I have updated this
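
The docstring point in this thread (one `dtype` for every trainable parameter) can be illustrated with a small standalone sketch. `init_linear` is a hypothetical initializer written for this example, and the uniform initialization range is an assumption; it is not taken from the PR's diff.

```python
import math

import jax
import jax.numpy as jnp


def init_linear(in_features, out_features, dtype, *, key, use_bias=True):
    # Hypothetical initializer: the same dtype is used for every
    # trainable parameter, i.e. both the weight and the bias.
    wkey, bkey = jax.random.split(key)
    lim = 1 / math.sqrt(in_features)
    weight = jax.random.uniform(
        wkey, (out_features, in_features), dtype=dtype, minval=-lim, maxval=lim
    )
    bias = None
    if use_bias:
        bias = jax.random.uniform(
            bkey, (out_features,), dtype=dtype, minval=-lim, maxval=lim
        )
    return weight, bias


weight, bias = init_linear(3, 4, jnp.float16, key=jax.random.PRNGKey(0))
```

Checking `weight.dtype` and `bias.dtype` against a dtype other than float32 is also the kind of assertion a simple test for this argument could make.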

@patrick-kidger
Owner

I think I like this! I wouldn't worry about the type annotation, so the only thing I'd tweak is the docstring a little bit.

I don't think we need to make dtype an attribute / I wouldn't try to raise a warning in __call__. That opens up too great a can of worms, I think, as in general there are all kinds of cases where folks might want mixed dtypes.
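
One example of intentionally mixed dtypes is low-precision parameters combined with a full-precision input. The sketch below uses plain `jax.numpy` arrays as stand-ins for a layer's weight and its input; under JAX's type promotion the result is well defined, so a warning here would fire on valid code.

```python
import jax.numpy as jnp

weight = jnp.ones((2, 3), dtype=jnp.float16)  # stand-in for a float16 layer weight
x = jnp.ones(3, dtype=jnp.float32)  # float32 input

# JAX promotes float16 and float32 to float32 for the matrix product.
out = weight @ x
```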

@AakashKumarNain
Contributor Author

> I think I like this! I wouldn't worry about the type annotation, so the only thing I'd tweak is the docstring a little bit.

I have updated the docstring to reflect the effect of this argument in a better way.

> I don't think we need to make dtype an attribute / I wouldn't try to raise a warning in `__call__`. That opens up too great a can of worms, I think, as in general there are all kinds of cases where folks might want mixed dtypes.

SGTM

@patrick-kidger
Owner

Alright, I think I'm happy with this pattern! Since we're settled on that, do I understand correctly that you're now aiming to repeat this across other parts of eqx.nn?

@AakashKumarNain
Contributor Author

Yes, I will send a separate PR for that. You can merge this one so that I am also unblocked for my Mistral based experiments.

@patrick-kidger patrick-kidger merged commit 3061c18 into patrick-kidger:dev Mar 16, 2024
2 checks passed
@patrick-kidger
Owner

Okay, great!
Just merged into dev.

@AakashKumarNain
Contributor Author

Thank you for the review and the help

@patrick-kidger patrick-kidger mentioned this pull request Apr 14, 2024
patrick-kidger pushed a commit that referenced this pull request Apr 14, 2024
* add dtype and format code

* add a simple test for checking dtype other than float32

* fix default dtype and format code

* refine documentation for the dtype argument
3 participants