From 0219671b59ae26758310449414662ade0b9fd649 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Wed, 9 Oct 2024 19:41:20 +0200 Subject: [PATCH] fix default for intermediate quadratic layers in ICNN (#587) --- src/ott/neural/networks/icnn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ott/neural/networks/icnn.py b/src/ott/neural/networks/icnn.py index b6f3fd049..9ee3e5080 100644 --- a/src/ott/neural/networks/icnn.py +++ b/src/ott/neural/networks/icnn.py @@ -116,7 +116,8 @@ def _get_wx(self, dim: int, rank: int) -> nn.Module: num_potentials=dim, use_linear=True, use_bias=True, - kernel_diag_init=nn.initializers.zeros, + kernel_diag_init=nn.initializers.constant(-2.0), + rectifier_fn=jax.nn.softplus, kernel_lr_init=self.init_fn, kernel_linear_init=self.init_fn, bias_init=nn.initializers.zeros,