Skip to content

Commit

Permalink
Stax: Add embedding layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaehunro committed Feb 4, 2020
1 parent ddc83e0 commit 3898033
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion jax/experimental/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@

from jax.nn import (relu, log_softmax, softmax, softplus, sigmoid, elu,
leaky_relu, selu, gelu, normalize)
from jax.nn.initializers import glorot_normal, normal, ones, zeros
from jax.nn.initializers import glorot_normal, normal, ones, uniform, zeros
from jax.ops import index_update

# aliases for backwards compatibility
glorot = glorot_normal
Expand Down Expand Up @@ -144,6 +145,27 @@ def apply_fun(params, x, **kwargs):
return init_fun, apply_fun


def Embedding(vocab_size,
embedding_size,
padding_idx=None,
embedding_init=uniform()):
"""Layer construction function for an embedding layer."""

def init_fun(rng, input_shape):
embedding_shape = (vocab_size, embedding_size)
embedding_table = embedding_init(rng, embedding_shape)
if padding_idx is not None:
embedding_table = index_update(embedding_table, padding_idx, 0.)
output_shape = input_shape + (embedding_size,)
return output_shape, (embedding_table,)

def apply_fun(params, inputs, **kwargs):
embedding_table = params[0]
return embedding_table[inputs]

return init_fun, apply_fun


def elementwise(fun, **fun_kwargs):
"""Layer that applies a scalar function elementwise on its inputs."""
init_fun = lambda rng, input_shape: (input_shape, ())
Expand Down

0 comments on commit 3898033

Please sign in to comment.