Skip to content

Commit

Permalink
Merge pull request #9052 from jpuigcerver:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 430680329
  • Loading branch information
jax authors committed Feb 24, 2022
2 parents a9a827e + 86e8928 commit 3948fde
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/_src/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

def constant(value, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init

def uniform(scale=1e-2, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
Expand Down
1 change: 1 addition & 0 deletions jax/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# flake8: noqa: F401
from jax._src.nn.initializers import (
constant as constant,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,
Expand Down

0 comments on commit 3948fde

Please sign in to comment.