Skip to content

Commit

Permalink
Restructure neural module (#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 authored Nov 20, 2023
1 parent 941d610 commit b961b3b
Show file tree
Hide file tree
Showing 23 changed files with 404 additions and 458 deletions.
2 changes: 1 addition & 1 deletion docs/solvers/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Models
.. autosummary::
:toctree: _autosummary

models.ModelBase
models.BaseW2NeuralDual
models.ICNN
models.MLP

Expand Down
20 changes: 19 additions & 1 deletion src/ott/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import geometry, initializers, math, problems, solvers, tools, utils
import contextlib

from . import (
datasets,
geometry,
initializers,
math,
problems,
solvers,
tools,
utils,
)

with contextlib.suppress(ImportError):
# TODO(michalk8): add warning that neural module is not imported
from . import neural

from ._version import __version__

del contextlib
File renamed without changes.
6 changes: 0 additions & 6 deletions src/ott/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib

from . import linear, quadratic

with contextlib.suppress(ImportError):
from . import nn
del contextlib
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import initializers
from . import models, solvers
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import dataset
from . import conjugate_solvers, layers, models
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.
from typing import Any, Callable, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax import linen as nn

__all__ = ["PositiveDense", "PosDefPotentials"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,214 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.core import frozen_dict
from flax.training import train_state
from jax import numpy as jnp
from jax.nn import initializers

from ott import utils
from ott.geometry import geometry
from ott.initializers.linear import initializers
from ott.initializers.linear import initializers as lin_init
from ott.math import matrix_square_root
from ott.neural.models import layers
from ott.neural.solvers import neuraldual
from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.problems.linear import linear_problem
__all__ = ["ICNN", "MLP", "MetaInitializer"]

# TODO(michalk8): add initializer for NeuralDual?
__all__ = ["MetaInitializer", "MetaMLP"]

class ICNN(neuraldual.BaseW2NeuralDual):
    """Input convex neural network (ICNN) architecture with initialization.

    Implementation of input convex neural networks as introduced in
    :cite:`amos:17` with initialization schemes proposed by :cite:`bunne:22`.

    Args:
      dim_data: data dimensionality.
      dim_hidden: sequence specifying size of hidden dimensions. The
        output dimension of the last layer is 1 by default.
      init_std: value of standard deviation of weight initialization method.
      init_fn: choice of initialization method for weight matrices (default:
        :func:`jax.nn.initializers.normal`).
      act_fn: choice of activation function used in network architecture
        (needs to be convex, default: :obj:`jax.nn.relu`).
      pos_weights: Enforce positive weights with a projection.
        If ``False``, the positive weights should be enforced with clipping
        or regularization in the loss.
      gaussian_map_samples: Tuple of source and target points, used to
        initialize the ICNN to mimic the linear Bures map that morphs the
        (Gaussian approximation) of the input measure to that of the target
        measure. If ``None``, the identity initialization is used, and ICNN
        mimics half the squared Euclidean norm.
    """

    dim_data: int
    dim_hidden: Sequence[int]
    init_std: float = 1e-2
    init_fn: Callable = jax.nn.initializers.normal
    act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    pos_weights: bool = True
    gaussian_map_samples: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None

    @property
    def is_potential(self) -> bool:  # noqa: D102
        return True

    def setup(self) -> None:  # noqa: D102
        self.num_hidden = len(self.dim_hidden)

        if self.pos_weights:
            hid_dense = layers.PositiveDense
            # this function needs to be the inverse map of function
            # used in PositiveDense layers
            rescale = hid_dense.inv_rectifier_fn
        else:
            hid_dense = nn.Dense
            rescale = lambda x: x
        self.use_init = False

        # parameters of the quadratic potential: either fitted to the Gaussian
        # approximations of the provided samples, or the identity map
        if self.gaussian_map_samples is not None:
            factor, mean = self._compute_gaussian_map_params(
                self.gaussian_map_samples
            )
        else:
            factor, mean = self._compute_identity_map_params(self.dim_data)

        # layers acting on the previous hidden state ``z``
        w_zs = []
        # keep track of previous size to normalize accordingly
        normalization = 1

        for i in range(1, self.num_hidden):
            w_zs.append(
                hid_dense(
                    self.dim_hidden[i],
                    kernel_init=initializers.constant(
                        rescale(1.0 / normalization)
                    ),
                    use_bias=False,
                )
            )
            normalization = self.dim_hidden[i]
        # final layer computes average, still with normalized rescaling
        w_zs.append(
            hid_dense(
                1,
                kernel_init=initializers.constant(rescale(1.0 / normalization)),
                use_bias=False,
            )
        )
        self.w_zs = w_zs

        # positive definite potential (the identity mapping or linear OT)
        self.pos_def_potential = layers.PosDefPotentials(
            self.dim_data,
            num_potentials=1,
            kernel_init=lambda *_: factor,
            bias_init=lambda *_: mean,
            use_bias=True,
        )

        # subsequent layers re-injected into convex functions
        w_xs = []
        for i in range(self.num_hidden):
            w_xs.append(
                nn.Dense(
                    self.dim_hidden[i],
                    kernel_init=self.init_fn(self.init_std),
                    bias_init=initializers.constant(0.),
                    use_bias=True,
                )
            )
        # final layer, to output number
        w_xs.append(
            nn.Dense(
                1,
                kernel_init=self.init_fn(self.init_std),
                bias_init=initializers.constant(0.),
                use_bias=True,
            )
        )
        self.w_xs = w_xs

    @staticmethod
    def _compute_gaussian_map_params(
        samples: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Compute kernel and bias of the potential from two point clouds.

        Gaussians are fitted to the source and target samples and the linear
        map ``M`` between their scales is computed.

        Args:
          samples: Tuple of ``(source, target)`` points.

        Returns:
          The matrix square root of ``M`` and the bias ``b``, each with a
          leading axis of size 1 added.
        """
        # imported lazily, as in the original implementation
        from ott.tools.gaussian_mixture import gaussian

        source, target = samples
        g_s = gaussian.Gaussian.from_samples(source)
        g_t = gaussian.Gaussian.from_samples(target)
        lin_op = g_s.scale.gaussian_map(g_t.scale)
        # The Gaussian map is T(x) = m_t + M (x - m_s) = M (x - b) with
        # b = m_s - M^{-1} m_t, hence the *source* mean is the first term
        # (the previous code used the target mean twice).
        # NOTE(review): assumes ``PosDefPotentials`` evaluates
        # 0.5 * ||A^T (x - b)||^2 with A = sqrt(M) -- confirm against layers.
        b = jnp.squeeze(g_s.loc) - jnp.linalg.solve(
            lin_op, jnp.squeeze(g_t.loc)
        )
        lin_op = matrix_square_root.sqrtm_only(lin_op)
        return jnp.expand_dims(lin_op, 0), jnp.expand_dims(b, 0)

    @staticmethod
    def _compute_identity_map_params(
        input_dim: int
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return an identity kernel and a zero bias for ``input_dim``.

        Args:
          input_dim: Dimensionality of the data.

        Returns:
          Identity matrix of shape ``(1, input_dim, input_dim)`` and zeros of
          shape ``(1, input_dim)``.
        """
        A = jnp.eye(input_dim).reshape((1, input_dim, input_dim))
        b = jnp.zeros((1, input_dim))
        return A, b

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
        # first hidden state only depends on the input
        z = self.act_fn(self.w_xs[0](x))
        # every following layer combines the hidden state with the re-injected
        # input; ``w_zs[i]`` is paired with ``w_xs[i + 1]``
        for i in range(self.num_hidden):
            z = jnp.add(self.w_zs[i](z), self.w_xs[i + 1](x))
            z = self.act_fn(z)
        z += self.pos_def_potential(x)
        return z.squeeze()


class MLP(neuraldual.BaseW2NeuralDual):
    """Plain multi-layer perceptron, generally non-convex in its input.

    Args:
      dim_hidden: widths of the hidden layers. The last layer's output
        dimension is set automatically: 1 when :attr:`is_potential` is
        ``True``, the input dimension otherwise.
      is_potential: if ``True`` the network models the potential, otherwise
        it models the gradient of the potential.
      act_fn: activation applied after every hidden layer.
    """

    dim_hidden: Sequence[int]
    is_potential: bool = True
    act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.leaky_relu

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # noqa: D102
        # a single point is temporarily promoted to a batch of size one
        is_single_point = x.ndim == 1
        if is_single_point:
            x = jnp.expand_dims(x, 0)
        assert x.ndim == 2, x.ndim
        input_dim = x.shape[-1]

        # hidden stack; layers are created in order inside the compact method
        hidden = x
        for width in self.dim_hidden:
            hidden = self.act_fn(nn.Dense(width, use_bias=True)(hidden))

        if not self.is_potential:
            # gradient model: residual connection around the network
            out = x + nn.Dense(input_dim, use_bias=True)(hidden)
        else:
            # potential model: scalar head plus half the squared norm of x
            out = nn.Dense(1, use_bias=True)(hidden).squeeze(-1)
            out = out + 0.5 * jax.vmap(jnp.dot)(x, x)

        if is_single_point:
            return out.squeeze(0)
        return out


@jax.tree_util.register_pytree_node_class
class MetaInitializer(initializers.DefaultInitializer):
class MetaInitializer(lin_init.DefaultInitializer):
"""Meta OT Initializer with a fixed geometry :cite:`amos:22`.
This initializer consists of a predictive model that outputs the
Expand All @@ -44,13 +230,12 @@ class MetaInitializer(initializers.DefaultInitializer):
The model's parameters are learned using a training set of OT
instances (multiple pairs of probability weights), that assume the
**same** geometry ``geom`` is used throughout, both for training and
evaluation. The meta model defaults to the MLP in
:class:`~ott.initializers.nn.initializers.MetaMLP` and, with batched problem
instances passed into :meth:`update`.
evaluation.
Args:
geom: The fixed geometry of the problem instances.
meta_model: The model to predict the potential :math:`f` from the measures.
TODO(marcocuturi): add explanation here what arguments to expect.
opt: The optimizer to update the parameters. If ``None``, use
:func:`optax.adam` with :math:`0.001` learning rate.
rng: The PRNG key to use for initializing the model.
Expand All @@ -75,7 +260,7 @@ class MetaInitializer(initializers.DefaultInitializer):
def __init__(
self,
geom: geometry.Geometry,
meta_model: Optional[nn.Module] = None,
meta_model: nn.Module,
opt: Optional[optax.GradientTransformation
] = optax.adam(learning_rate=1e-3), # noqa: B008
rng: Optional[jax.Array] = None,
Expand All @@ -87,9 +272,8 @@ def __init__(
self.rng = utils.default_prng_key(rng)

na, nb = geom.shape
self.meta_model = MetaMLP(
potential_size=na
) if meta_model is None else meta_model
# TODO(michalk8): add again some default MLP
self.meta_model = meta_model

if state is None:
# Initialize the model's training state.
Expand Down Expand Up @@ -219,37 +403,3 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
"rng": self.rng,
"state": self.state
}


class MetaMLP(nn.Module):
    r"""MLP used by :class:`~ott.initializers.nn.initializers.MetaInitializer`.

    The network :math:`\hat f_\theta(a, b)` takes the probabilities of both
    measures and predicts the optimal dual potential :math:`f`.

    Args:
      potential_size: The dimensionality of :math:`f`.
      num_hidden_units: The number of hidden units in each layer.
      num_hidden_layers: The number of hidden layers.
    """

    potential_size: int
    num_hidden_units: int = 512
    num_hidden_layers: int = 3

    @nn.compact
    def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        r"""Predict the potential from the two weight vectors.

        Args:
          a: Probabilities of the :math:`\alpha` measure's atoms.
          b: Probabilities of the :math:`\beta` measure's atoms.

        Returns:
          The :math:`f` potential.
        """
        # all layers use the dtype of the first weight vector
        param_dtype = a.dtype
        features = jnp.concatenate((a, b))
        for _layer in range(self.num_hidden_layers):
            hidden_layer = nn.Dense(self.num_hidden_units, dtype=param_dtype)
            features = nn.relu(hidden_layer(features))
        output_layer = nn.Dense(self.potential_size, dtype=param_dtype)
        return output_layer(features)
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import conjugate_solvers, layers, losses, models, neuraldual
from . import losses, map_estimator, neuraldual
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from flax.training import train_state

from ott import utils
from ott.solvers.nn import models
from ott.neural.solvers import neuraldual

__all__ = ["MapEstimator"]

Expand Down Expand Up @@ -77,7 +77,7 @@ class MapEstimator:
def __init__(
self,
dim_data: int,
model: models.ModelBase,
model: neuraldual.BaseW2NeuralDual,
optimizer: Optional[optax.OptState] = None,
fitting_loss: Optional[Callable[[jnp.ndarray, jnp.ndarray],
Tuple[float, Optional[Any]]]] = None,
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
def setup(
self,
dim_data: int,
neural_net: models.ModelBase,
neural_net: neuraldual.BaseW2NeuralDual,
optimizer: optax.OptState,
):
"""Setup all components required to train the network."""
Expand Down
Loading

0 comments on commit b961b3b

Please sign in to comment.