Skip to content

Commit

Permalink
add dtype to kernel (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK authored May 25, 2023
1 parent 795fabe commit 64d26ac
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/ott/solvers/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"kernel", self.kernel_init, (inputs.shape[-1], self.dim_hidden)
)
kernel = self.rectifier_fn(kernel)
kernel = jnp.asarray(kernel, self.dtype)
y = jax.lax.dot_general(
inputs,
kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
Expand Down

0 comments on commit 64d26ac

Please sign in to comment.