-
-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple dtype
argument addition
#680
Merged
Merged
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
e3284da
add dtype and format code
AakashKumarNain 4c48a90
add a simple test for checking dtype other than float32
AakashKumarNain da2ab5a
fix default dtype and format code
AakashKumarNain 65cc913
refine documentation for the dtype argument
AakashKumarNain File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
import jax.random as jrandom | ||
from jaxtyping import Array, PRNGKeyArray | ||
|
||
from .._misc import default_floating_dtype | ||
from .._module import field, Module | ||
|
||
|
||
|
@@ -23,6 +24,7 @@ def __init__( | |
in_features: Union[int, Literal["scalar"]], | ||
out_features: Union[int, Literal["scalar"]], | ||
use_bias: bool = True, | ||
dtype=None, | ||
*, | ||
key: PRNGKeyArray, | ||
): | ||
|
@@ -33,6 +35,8 @@ def __init__( | |
- `out_features`: The output size. The output from the layer will be a vector | ||
of shape `(out_features,)`. | ||
- `use_bias`: Whether to add on a bias as well. | ||
- `dtype`: The dtype to use. Defaults to either `jax.numpy.float32` or | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should also tell the user that it applies to all trainable parameters (not just the weights but also the bias) IMO There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have updated this |
||
`jax.numpy.float64` depending on whether JAX is in 64-bit mode. | ||
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter | ||
initialisation. (Keyword only argument.) | ||
|
||
|
@@ -46,11 +50,17 @@ def __init__( | |
in_features_ = 1 if in_features == "scalar" else in_features | ||
out_features_ = 1 if out_features == "scalar" else out_features | ||
lim = 1 / math.sqrt(in_features_) | ||
|
||
if dtype is None: | ||
dtype = default_floating_dtype() | ||
|
||
self.weight = jrandom.uniform( | ||
wkey, (out_features_, in_features_), minval=-lim, maxval=lim | ||
wkey, (out_features_, in_features_), minval=-lim, maxval=lim, dtype=dtype | ||
) | ||
if use_bias: | ||
self.bias = jrandom.uniform(bkey, (out_features_,), minval=-lim, maxval=lim) | ||
self.bias = jrandom.uniform( | ||
bkey, (out_features_,), minval=-lim, maxval=lim, dtype=dtype | ||
) | ||
else: | ||
self.bias = None | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type of dtype is missing here. I don't have my laptop for a couple of days but I think the type of dtype is jax.numpy.dtype. And it should also be something like "dtype: Optional[jax.numpy.dtype] = None"