From e88f61f636da6645ea90664355500d40f2c9437f Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Sat, 9 Nov 2024 00:35:03 +0100 Subject: [PATCH] add glossary (#594) * add glossary * fixes * fixes * fixes * update lineax version * typo * fixes * Fix typos --- docs/glossary.rst | 298 ++++++++++++++++++ docs/index.rst | 11 +- docs/math.rst | 4 +- docs/spelling/technical.txt | 7 + pyproject.toml | 6 +- .../neural/methods/expectile_neural_dual.py | 4 +- src/ott/neural/methods/monge_gap.py | 3 +- .../linear/implicit_differentiation.py | 35 +- src/ott/solvers/linear/sinkhorn.py | 241 ++++---------- .../solvers/quadratic/gromov_wasserstein.py | 2 +- tests/solvers/linear/sinkhorn_diff_test.py | 1 + 11 files changed, 412 insertions(+), 200 deletions(-) create mode 100644 docs/glossary.rst diff --git a/docs/glossary.rst b/docs/glossary.rst new file mode 100644 index 000000000..3372ff651 --- /dev/null +++ b/docs/glossary.rst @@ -0,0 +1,298 @@ +Glossary +======== + +.. glossary:: + :sorted: + + coupling + A coupling of two probability measures :math:`\mu` and :math:`\nu` is a + probability measure on the product space of their respective supports. + When the coupling is balanced, the first and second marginals of that + probability measure must coincide with :math:`\mu` and + :math:`\nu` respectively. Equivalently, given two non-negative vectors + :math:`a\in\mathbb{R}^n` and :math:`b\in\mathbb{R}^m`, a coupling is a + matrix form is a non-negative matrix :math:`P` of size + :math:`n\times m`. When the coupling is balanced :math:`P` is in their + :term:`transportation polytope` :math:`U(a,b)`. + + dual Kantorovich potentials + Real-valued functions or vectors that solve the + :term:`dual Kantorovich problem`. + + dual Kantorovich problem + Dual formulation of the :term:`Kantorovich problem`, seeking two + vectors, two :term:`dual Kantorovich potentials`, such that, given a + cost matrix :math:`C` of size ``[n, m]`` and two + probability vectors :math:`a \in\mathbb{R}^n,b\in\mathbb{R}^m`, they + belong to the :term:`dual transportation polyhedron` :math:`D(C)` and + maximize: + + .. math:: + + \max_{f,g \,\in D(C)} \langle f,a \rangle + \langle g,b \rangle. + + This problem admits a continuous formulation between two probability + distributions :math:`\mu,\nu`: + + .. math:: + + \max_{f\oplus g\leq c} \int f d\mu + \int g d\nu, + + where :math:`f,g` are real-valued functions on the supports of + :math:`\mu,\nu` and :math:`f\oplus g\leq c` means that for any pair + :math:`x,y` in the respective supports, :math:`f(x)+g(y)\leq c(x,y)`. + + dual transportation polyhedron + Given a :math:`n\times m` cost matrix :math:`C`, denotes the set of + pairs of vectors + + .. math:: + + D(C):= \{f \in\mathbb{R}^n, g \in \mathbb{R}^m + | f_i + g_j \leq C_{ij}\}. + + dualize + Within the context of optimization, the process of converting a + constrained optimization problem into an unconstrained one, by + transforming constraints into penalty terms in the objective function. + + entropy-regularized optimal transport + The data of the entropy regularized OT (EOT) problem is parameterized by + a cost matrix :math:`C` of size ``[n, m]`` and two vectors :math:`a,b` + of non-negative weights of respective size ``n`` and ``m``. + The parameters of the EOT problem consist of three numbers + :math:`\varepsilon, \tau_a, \tau_b`. + + The optimization variables are a pair of vectors of sizes ``n`` and + ``m`` denoted as :math:`f` and :math:`g`. 
+ + Using the reparameterization for :math:`\rho_a` and + :math:`\rho_b` using + :math:`\tau_a=\rho_a /(\varepsilon + \rho_a)` and + :math:`\tau_b=\rho_b /(\varepsilon + \rho_b)`, the EOT optimization + problem reads: + + .. math:: + + \max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b, + \phi_b^{*}(-g) \rangle - \varepsilon \left(\langle e^{f/\varepsilon}, + e^{-C/\varepsilon} e^{g/\varepsilon} \rangle -|a||b|\right) + + where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and + :math:`\phi_a^{*}(z) = \rho_a e^{z/\varepsilon}`, its Legendre transform + :cite:`sejourne:19`. + + That problem can also be written, instead, using positive scaling + vectors `u`, `v` of size ``n``, ``m``, and the kernel matrix + :math:`K := e^{-C/\varepsilon}`, as + + .. math:: + + \max_{u, v >0} - \langle a,\phi_a^{*}(-\varepsilon\log u) \rangle + + \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - + \langle u, K v \rangle + + Both of these problems can be written with a *primal* formulation, that + solves the :term:`unbalanced` optimal transport problem with a variable + matrix :math:`P` of size ``n`` x ``m`` and positive entries: + + .. math:: + + \min_{P>0} \langle P,C \rangle +\varepsilon \text{KL}(P | ab^T) + + \rho_a \text{KL}(P\mathbf{1}_m | a) + + \rho_b \text{KL}(P^T \mathbf{1}_n | b) + + where :math:`\text{KL}` is the generalized Kullback-Leibler divergence. + + The very same primal problem can also be written using a kernel + :math:`K` instead of a cost :math:`C` as well: + + .. math:: + + \min_{P>0}\, \varepsilon \text{KL}(P|K) + + \rho_a \text{KL}(P\mathbf{1}_m | a) + + \rho_b \text{KL}(P^T \mathbf{1}_n | b) + + The *original* OT problem taught in linear programming courses is + recovered by using the formulation above relying on the cost :math:`C`, + and letting :math:`\varepsilon \rightarrow 0`, and + :math:`\rho_a, \rho_b \rightarrow \infty`. + In that case the entropy disappears, whereas the :math:`\text{KL}` + regularization above become constraints on the marginals of :math:`P`: + This results in a standard min cost flow problem also called the + :term:`Kantorovich problem`. + + The *balanced* regularized OT problem is recovered for finite + :math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow + \infty`. This problem can be shown to be equivalent to a matrix scaling + problem, which can be solved using the :term:`Sinkhorn algorithm`. + To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the + Sinkhorn function uses parameters ``tau_a`` and ``tau_b`` equal + respectively to :math:`\rho_a /(\varepsilon + \rho_a)` and + :math:`\rho_b / (\varepsilon + \rho_b)` instead. Setting either of these + parameters to 1 corresponds to setting the corresponding + :math:`\rho_a, \rho_b` to :math:`\infty`. + + envelope theorem + The envelope theorem is a major result about the differentiability + properties of the value function of a parameterized optimization + problem. Namely, that for a function :math:`f` defined implicitly as an + optimal objective parameterized by a vector :math:`x`, + + .. math:: + h(x):=\min_z s(x,z), z^\star(x):=\arg\min_z s(x,z) + + one has + + .. math:: + \nabla h(x)=\nabla_1 s(x,z^\star(x)) + + stating in effect that the optimal :math:`z^\star(x)` does not + need to be differentiated w.r.t. :math:`x` when computing the + gradient of :math:`h`. Note that this result is not valid for higher + order differentiation. 
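
    To make the envelope theorem concrete, here is a minimal JAX sketch (an
    illustration only, not part of the glossary text) on the toy objective
    :math:`s(x, z) = \|z - x\|^2 + \|z\|^2`, whose minimizer is
    :math:`z^\star(x) = x/2`; the names ``s``, ``z_star`` and ``h`` are
    placeholders.

    .. code-block:: python

        # Envelope theorem on a toy objective: the gradient of the value
        # function h(x) = min_z s(x, z) equals the partial gradient of s in x,
        # evaluated at the (frozen) minimizer z*(x); no differentiation
        # through z*(x) is needed.
        import jax
        import jax.numpy as jnp

        def s(x, z):
          return jnp.sum((z - x) ** 2) + jnp.sum(z ** 2)

        def z_star(x):
          # Closed-form argmin over z; in practice, the output of a solver.
          return 0.5 * x

        def h(x):
          # Value function h(x) = min_z s(x, z).
          return s(x, z_star(x))

        x = jnp.array([1.0, -2.0, 3.0])
        # Differentiate through the minimizer (in general, through an
        # unrolled solver).
        grad_full = jax.grad(h)(x)
        # Envelope theorem: gradient of s w.r.t. x only, with z*(x) frozen.
        grad_envelope = jax.grad(s, argnums=0)(x, jax.lax.stop_gradient(z_star(x)))
        print(jnp.allclose(grad_full, grad_envelope))  # True

    Freezing optimal variables in this way is what the ``use_danskin`` flag of
    the Sinkhorn solver below relies on when differentiating ``reg_ot_cost``.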
+ + ground cost + A real-valued function of two variables, :math:`c(x,y)` that describes + the cost needed to displace a point :math:`x` in a source measure to + :math:`y` in a target measure. + + implicit differentiation + Differentiation technique to compute the vector-Jacobian + product of the minimizer of an optimization procedure by considering + that small variations in the input would still result in minimizers + that verify optimality conditions (KKT or first-order conditions). These + identities can then help recover the vector-Jacobian operator by + inverting a linear system. + + input-convex neural networks + A neural network architecture for vectors with a few distinguishing + features: some parameters of this NN must be non-negative, the NN's + output is real-valued and guaranteed to be convex in the input vector. + + Kantorovich problem + Linear program that is the original formulation of optimal transport + between two point-clouds, seeking an optimal :term:`coupling` matrix + :math:`P`. The problem is parameterized by a cost matrix :math:`C` of + size ``[n, m]`` and two probability vectors :math:`a,b` of non-negative + weights of respective sizes ``n`` and ``m``, summing to :math:`1`. + The :term:`coupling` is in the :term:`transportation polytope` + :math:`U(a,b)` and must minimize the objective + + .. math:: + + \min_{P \in U(a,b)} \langle P,C \rangle = \sum_{ij} P_{ij} C_{ij}. + + This linear program can be seen as the primal problem of the + :term:`dual Kantorovich problem`. Alternatively, this problem admits a + continuous formulation between two probability distributions + :math:`\mu,\nu`: + + .. math:: + + \min_{\pi \in \Pi(\mu,\nu)} \iint cd\pi. + + where :math:`\pi` is a coupling density with first marginal :math:`\mu` + and second marginal :math:`\nu`. + + matching + A bijective pairing between two families of points of the same size + :math:`N`, parameterized using a permutation of size :math:`N`. + + multimarginal coupling + A multimarginal coupling of :math:`N` probability measures + :math:`\mu_1, \dots, \mu_N` is a probability measure on the product + space of their respective supports, such that its marginals coincide, + in that order, with :math:`\mu_1, \dots, \mu_N`. + + push-forward map + Given a measurable mapping :math:`T` (e.g. a vector to vector map), + the push-forward measure of :math:`\mu` by :math:`T` denoted as + :math:`T\#\mu`, is the measure defined to be such that for any + measurable set :math:`B`, :math:`T\#\mu(B)=\mu(T^{-1}(B))`. Intuitively, + it is the measure obtained by applying the map :math:`T` to all points + described in :math:`\mu`. See also the + `Wikipedia definition `_. + + optimal transport + Mathematical theory used to describe and characterize efficient + transformations between probability measures. Such transformations can + be studied between continuous probability measures (e.g. densities) and + estimated using samples from probability measures. + + Sinkhorn algorithm + Fixed point iteration that solves the + :term:`entropy-regularized optimal transport` problem (EOT). + The Sinkhorn algorithm solves the EOT problem by seeking optimal + :math:`f`, :math:`g` :term:`dual Kantorovich potentials` (or + alternatively their parameterization as positive scaling vectors + :math:`u`, :math:`v`), rather than seeking + a :term:`coupling` :math:`P`. This is mostly for efficiency + (potentials and scalings have a ``n + m`` memory footprint, rather than + ``n m`` required to store :math:`P`). 
Note that an optimal coupling + :math:`P^{\star}` can be recovered from optimal potentials + :math:`f^{\star}`, :math:`g^{\star}` or scaling :math:`u^{\star}`, + :math:`v^{\star}`. + + .. math:: + + P^{\star} = \exp\left(\frac{f^{\star}\mathbf{1}_m^T + + \mathbf{1}_n g^{*T}-C}{\varepsilon}\right) \text{ or } P^{\star} + = \text{diag}(u^{\star}) K \text{diag}(v^{\star}) + + By default, the Sinkhorn algorithm solves this dual problem using block + coordinate ascent, i.e. devising an update for each :math:`f` and + :math:`g` (resp. :math:`u` and :math:`v`) that cancels their respective + gradients, one at a time. + + transport map + A function :math:`T` that associates to each point :math:`x` in the + support of a source distribution :math:`\mu` another point :math:`T(x)` + in the support of a target distribution :math:`\nu`, which must + satisfy a :term:`push-forward map` constraint :math:`T\#\mu = \nu`. + + transport plan + A :term:`coupling` (either in matrix or joint density form), + quantifying the strength of association between any point :math:`x`` in + the source distribution :math:`\mu` and target point :math:`y`` in the + :math:`\nu` distribution. + + transportation polytope + Given two probability vectors :math:`a,b` of non-negative weights of + respective size ``n`` and ``m``, summing each to :math:`1`, the + transportation polytope is the set of matrices + + .. math:: + + U(a,b):= \{P \in \mathbb{R}^{n\times m} | , + P\mathbf{1}_m = a, P^T\mathbf{1}_n=b \}. + + twist condition + Given a :term:`ground cost` function :math:`c(x, y)` taking two input + vectors, this refers to the requirement that at any given point + :math:`x`, the map :math:`y \rightarrow \nabla_1 c(x, y)` be invertible. + Although not necessary, this condition simplifies many proofs when + proving the existence of optimal :term:`transport map`. + + unbalanced + A generalization of the OT problem defined to bring more flexibility to + optimal transport computations. Such a generalization arises when + considering unnormalized probability distributions on the product space + of the supports :math:`\mu` and :math:`\nu`, without requiring that its + marginal coincides exactly with :math:`\mu` and :math:`\nu`. + + unrolling + Automatic differentiation technique to compute the vector-Jacobian + product of the minimizer of an optimization procedure by treating the + iterations (used to converge from an initial point) as layers in a + computational graph, and computing its differential using reverse-order + differentiation. + + Wasserstein distance + Distance between two probability functions parameterized by a + :term:`ground cost` function that is equal to the optimal objective + reached when solving the :term:`Kantorovich problem`. Such a distance + is truly a distance (in the sense that it satisfies all 3 + `metric axioms `_), + as long as the :term:`ground cost` is itself a distance to a power + :math:`p\leq 1`, and the :math:`1/p` power of the objective is taken. diff --git a/docs/index.rst b/docs/index.rst index 407552240..0c12df3d7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -92,10 +92,12 @@ Packages solution. - :mod:`ott.experimental` lists tools whose API is not mature yet to be included in the main toolbox, with changes expected in the near future, but which might - still prove useful for users. This includes at the moment the multimarginal - Sinkhorn solver class :class:`~ott.experimental.mmsinkhorn.MMSinkhon`. 
-- :mod:`ott.neural` provides tools to parameterize optimal transport maps, - couplings or conditional probabilities as neural networks. + still prove useful for users. This includes at the moment the + :class:`~ott.solvers.linear.mmsinkhorn.MMSinkhorn` solver class to compute an + optimal :term:`multimarginal coupling` +- :mod:`ott.neural` provides tools to parameterize and optimal + :term:`transport map`, :term:`coupling` or conditional probabilities as + neural networks. - :mod:`ott.tools` provides an interface to exploit OT solutions, as produced by solvers from the :mod:`ott.solvers` module. Such tasks include computing approximations to Wasserstein distances :cite:`genevay:18,sejourne:19`, @@ -130,6 +132,7 @@ Packages :maxdepth: 1 :caption: References + glossary bibliography contributing diff --git a/docs/math.rst b/docs/math.rst index 960ac9e78..4d66bdb5c 100644 --- a/docs/math.rst +++ b/docs/math.rst @@ -10,8 +10,8 @@ that can be automatically differentiated, and which might be of more general interest to other `JAX` users. :mod:`ott.math.matrix_square_root` contains an implementation of the matrix square-root using the Newton-Schulz iterations. That implementation is -itself differentiable using either implicit differentiation or unrolling of the -updates of these iterations. +itself differentiable using either :term:`implicit differentiation` or +:term:`unrolling` of the updates of these iterations. :mod:`ott.math.utils` contains various low-level routines re-implemented for their usage in `JAX`. Of particular interest are the custom jvp/vjp re-implementations for `logsumexp` and `norm` that have a behavior that differs, diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index 106a26bfb..50337be19 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -41,6 +41,7 @@ backpropagation barycenter barycenters barycentric +bijective binarized boolean centroids @@ -60,6 +61,7 @@ dimensionality discretization discretize downweighted +dualize duals eigendecomposition elementwise @@ -91,6 +93,8 @@ linearized logit macOS methylation +minimizer +minimizers multimarginal neuroimaging normed @@ -106,6 +110,7 @@ piecewise pluripotent polymatching polynomials +polytope positivity postfix potentials @@ -127,6 +132,7 @@ regularizer regularizers reimplementation renormalize +reparameterization reproducibility rescale rescaled @@ -156,6 +162,7 @@ thresholding transcriptome undirected univariate +unnormalized unscaled url vectorized diff --git a/pyproject.toml b/pyproject.toml index 2bec8e3b7..07685dc53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ authors = [ dependencies = [ "jax>=0.4.0", "jaxopt>=0.8", - "lineax>=0.0.5", + "lineax>=0.0.7", "numpy>=1.20.0", "typing_extensions; python_version <= '3.9'", ] @@ -162,13 +162,15 @@ ignore_directives = [ "autosummary", "automodule", "autoclass", - "bibliography" + "bibliography", + "glossary", ] ignore_roles = [ "class", "doc", "mod", "cite", + "term", ] [tool.doc8] diff --git a/src/ott/neural/methods/expectile_neural_dual.py b/src/ott/neural/methods/expectile_neural_dual.py index 28da42f17..da2e13638 100644 --- a/src/ott/neural/methods/expectile_neural_dual.py +++ b/src/ott/neural/methods/expectile_neural_dual.py @@ -144,8 +144,8 @@ class ExpectileNeuralDual: It solves the dual optimal transport problem for a specified cost function :math:`c(x, y)` between two measures :math:`\alpha` and :math:`\beta` in :math:`d`-dimensional Euclidean space with additional regularization on - Kantorovich 
potentials. The expectile regularization enforces binding - conditions on the learning dual potentials :math:`f` and :math:`g`. + :term:`dual Kantorovich potentials`. The expectile regularization enforces + binding conditions on the learning dual potentials :math:`f` and :math:`g`. The main optimization objective is .. math:: diff --git a/src/ott/neural/methods/monge_gap.py b/src/ott/neural/methods/monge_gap.py index cad282bc8..71e3446fb 100644 --- a/src/ott/neural/methods/monge_gap.py +++ b/src/ott/neural/methods/monge_gap.py @@ -123,7 +123,8 @@ def monge_gap_from_samples( W_{c, \varepsilon}(\frac{1}{n}\sum_i \delta_{x_i}, \frac{1}{n}\sum_i \delta_{y_i}) - where :math:`W_{c, \varepsilon}` is an entropy-regularized optimal transport + where :math:`W_{c, \varepsilon}` is an + :term:`entropy-regularized optimal transport` cost, the :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`. Args: diff --git a/src/ott/solvers/linear/implicit_differentiation.py b/src/ott/solvers/linear/implicit_differentiation.py index fbf98ce81..edc6544df 100644 --- a/src/ott/solvers/linear/implicit_differentiation.py +++ b/src/ott/solvers/linear/implicit_differentiation.py @@ -32,7 +32,11 @@ @utils.register_pytree_node class ImplicitDiff: - """Implicit differentiation of Sinkhorn algorithm. + """Implicit differentiation of the :term:`Sinkhorn algorithm`. + + This class encapsulates a few parameters and methods used to differentiate + the output of OT solvers w.r.t. relevant input variables such as point clouds + or cost functions. Args: solver: Callable to compute the solution to a linear problem. The callable @@ -54,17 +58,17 @@ class ImplicitDiff: symmetric: flag used to figure out whether the linear system solved in the implicit function theorem is symmetric or not. This happens when ``tau_a==tau_b``, and when ``a == b``, or the precondition_fun - is the identity. The flag is False by default, and is also tested against - ``tau_a==tau_b``. It needs to be set manually by the user in the more - favorable case where the system is guaranteed to be symmetric. + is the identity. The flag is :obj:`False` by default, and is also tested + against ``tau_a==tau_b``. It needs to be set manually by the user in the + more favorable case where the system is guaranteed to be symmetric. precondition_fun: Function used to precondition, on both sides, the linear system derived from first-order conditions of the regularized OT problem. That linear system typically involves an equality between marginals (or - simple transform of these marginals when the problem is unbalanced) and - another function of the potentials. When that function is specified, that - function is applied on both sides of the equality, before being further - differentiated to provide the Jacobians needed for implicit function - theorem differentiation. + simple transform of these marginals when the problem is :term:`unbalanced` + ) and another function of the potentials. When that function is specified, + that function is applied on both sides of the equality, before being + further differentiated to provide the Jacobians needed for + :term`implicit differentiation`. """ solver: Optional[Solver_t] = None @@ -80,10 +84,11 @@ def solve( g: jnp.ndarray, lse_mode: bool, ) -> jnp.ndarray: - r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``]. + r"""Apply minus inverse of Hessian of ``reg_ot_cost`` w.r.t. [``f``, ``g``]. 
- This function is used to carry out implicit differentiation of ``sinkhorn`` - outputs, notably optimal potentials ``f`` and ``g``. That differentiation + This function is used to carry out :term:`implicit differentiation` of + the outputs of the :term:`Sinkhorn algorithm`, notably + :term:`dual Kantorovich potentials` ``f`` and ``g``. That differentiation requires solving a linear system, using (and inverting) the Jacobian of (preconditioned) first-order conditions w.r.t. the reg-OT problem. @@ -112,9 +117,9 @@ def solve( \log`, as proposed in :cite:`cuturi:20a`. In both cases :math:`A` and :math:`D` are diagonal matrices, equal to the - row and - column marginals respectively, multiplied by the derivatives of :math:`h` - evaluated at those marginals, corrected (if handling the unbalanced case) + row and column marginals of the :term:`coupling` respectively, + multiplied by the derivatives of :math:`h` evaluated at those marginals, + corrected (if handling the :term:`unbalanced` case) by the second derivative of the part of the objective that ties potentials to the marginals (terms in ``phi_star``). When :math:`h` is the identity, :math:`B` and :math:`B^T` are equal respectively to the OT matrix and its diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 760f33f84..960061331 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -508,117 +508,8 @@ def g(self) -> jnp.ndarray: class Sinkhorn: r"""Sinkhorn solver. - The Sinkhorn algorithm is a fixed point iteration that solves a regularized - optimal transport (reg-OT) problem between two measures. - The optimization variables are a pair of vectors (called potentials, or - scalings when parameterized as exponential of the former). Calling this - function returns therefore a pair of optimal vectors. In addition to these, - it also returns the objective value achieved by these optimal vectors; - a vector of size ``max_iterations/inner_iterations`` that records the vector - of values recorded to monitor convergence, throughout the execution of the - algorithm (padded with `-1` if convergence happens before), as well as a - boolean to signify whether the algorithm has converged within the number of - iterations specified by the user. - - The reg-OT problem is specified by two measures, of respective sizes ``n`` and - ``m``. From the viewpoint of the ``sinkhorn`` function, these two measures are - only seen through a triplet (``geom``, ``a``, ``b``), where ``geom`` is a - ``Geometry`` object, and ``a`` and ``b`` are weight vectors of respective - sizes ``n`` and ``m``. Starting from two initial values for those potentials - or scalings (both can be defined by the user by passing value in - ``init_dual_a`` or ``init_dual_b``), the Sinkhorn algorithm will use - elementary operations that are carried out by the ``geom`` object. - - Math: - Given a geometry ``geom``, which provides a cost matrix :math:`C` with its - regularization parameter :math:`\varepsilon`, (or a kernel matrix :math:`K`) - the reg-OT problem consists in finding two vectors `f`, `g` of size ``n``, - ``m`` that maximize the following criterion. - - .. 
math:: - - \arg\max_{f, g}{- \langle a, \phi_a^{*}(-f) \rangle - \langle b, - \phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon}, - e^{-C/\varepsilon} e^{g/\varepsilon}} \rangle - - where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and - :math:`\phi_a^{*}(z) = \rho_a e^{z/\varepsilon}`, its Legendre transform - :cite:`sejourne:19`. - - That problem can also be written, instead, using positive scaling vectors - `u`, `v` of size ``n``, ``m``, handled with the kernel - :math:`K := e^{-C/\varepsilon}`, - - .. math:: - - \arg\max_{u, v >0} - \langle a,\phi_a^{*}(-\varepsilon\log u) \rangle + - \langle b, \phi_b^{*}(-\varepsilon\log v) \rangle - \langle u, K v \rangle - - Both of these problems corresponds, in their *primal* formulation, to - solving the unbalanced optimal transport problem with a variable matrix - :math:`P` of size ``n`` x ``m``: - - .. math:: - - \arg\min_{P>0} \langle P,C \rangle +\varepsilon \text{KL}(P | ab^T) - + \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n - | b) - - where :math:`KL` is the generalized Kullback-Leibler divergence. - - The very same primal problem can also be written using a kernel :math:`K` - instead of a cost :math:`C` as well: - - .. math:: - - \arg\min_{P} \varepsilon \text{KL}(P|K) - + \rho_a \text{KL}(P\mathbf{1}_m | a) + - \rho_b \text{KL}(P^T \mathbf{1}_n | b) - - The *original* OT problem taught in linear programming courses is recovered - by using the formulation above relying on the cost :math:`C`, and letting - :math:`\varepsilon \rightarrow 0`, and :math:`\rho_a, \rho_b \rightarrow - \infty`. - In that case the entropy disappears, whereas the :math:`KL` regularization - above become constraints on the marginals of :math:`P`: This results in a - standard min cost flow problem. This problem is not handled for now in this - toolbox, which focuses exclusively on the case :math:`\varepsilon > 0`. - - The *balanced* regularized OT problem is recovered for finite - :math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow - \infty`. This problem can be shown to be equivalent to a matrix scaling - problem, which can be solved using the Sinkhorn fixed-point algorithm. - To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the - ``sinkhorn`` function uses parameters ``tau_a`` and ``tau_b`` equal - respectively to :math:`\rho_a /(\varepsilon + \rho_a)` and - :math:`\rho_b / (\varepsilon + \rho_b)` instead. Setting either of these - parameters to 1 corresponds to setting the corresponding - :math:`\rho_a, \rho_b` to :math:`\infty`. - - The Sinkhorn algorithm solves the reg-OT problem by seeking optimal - :math:`f`, :math:`g` potentials (or alternatively their parameterization - as positive scaling vectors :math:`u`, :math:`v`), rather than solving the - primal problem in :math:`P`. This is mostly for efficiency (potentials and - scalings have a ``n + m`` memory footprint, rather than ``n m`` required - to store `P`). This is also because both problems are, in fact, equivalent, - since the optimal transport :math:`P^{\star}` can be recovered from - optimal potentials :math:`f^{\star}`, :math:`g^{\star}` or scaling - :math:`u^{\star}`, :math:`v^{\star}`, using the geometry's cost or kernel - matrix respectively: - - .. 
math:: - - P^{\star} = \exp\left(\frac{f^{\star}\mathbf{1}_m^T + \mathbf{1}_n g^{*T}- - C}{\varepsilon}\right) \text{ or } P^{\star} = \text{diag}(u^{\star}) K - \text{diag}(v^{\star}) - - By default, the Sinkhorn algorithm solves this dual problem in :math:`f, g` - or :math:`u, v` using block coordinate ascent, i.e. devising an update for - each :math:`f` and :math:`g` (resp. :math:`u` and :math:`v`) that cancels - their respective gradients, one at a time. These two iterations are repeated - ``inner_iterations`` times, after which the norm of these gradients will be - evaluated and compared with the ``threshold`` value. The iterations are then - repeated as long as that error exceeds ``threshold``. + The :term:`Sinkhorn algorithm` is a fixed point iteration that solves a + regularized optimal transport (reg-OT) problem between two measures. Note on Sinkhorn updates: The boolean flag ``lse_mode`` sets whether the algorithm is run in either: @@ -659,50 +550,47 @@ class Sinkhorn: can only be carried out using implicit differentiation, and that all momentum related parameters are ignored. - The ``parallel_dual_updates`` flag is set to ``False`` by default. In that - setting, ``g_v`` is first updated using the latest values for ``f_u`` and - ``g_v``, before proceeding to update ``f_u`` using that new value for - ``g_v``. When the flag is set to ``True``, both ``f_u`` and ``g_v`` are - updated simultaneously. Note that setting that choice to ``True`` requires - using some form of averaging (e.g. ``momentum=0.5``). Without this, and on - its own ``parallel_dual_updates`` won't work. + The ``parallel_dual_updates`` flag is set to :obj:`False` by default. In + that setting, ``g_v`` is first updated using the latest values for ``f_u`` + and ``g_v``, before proceeding to update ``f_u`` using that new value for + ``g_v``. When the flag is set to :obj:`True`, both ``f_u`` and ``g_v`` are + updated simultaneously. Note that setting that choice to :obj:`True` + requires using some form of averaging (e.g. ``momentum=0.5``). Without this, + and on its own ``parallel_dual_updates`` won't work. Differentiation: The optimal solutions ``f`` and ``g`` and the optimal objective (``reg_ot_cost``) outputted by the Sinkhorn algorithm can be differentiated w.r.t. relevant inputs ``geom``, ``a`` and ``b``. In the default setting, - implicit differentiation of the optimality conditions (``implicit_diff`` - not equal to ``None``), this has two consequences, treating ``f`` and ``g`` - differently from ``reg_ot_cost``. + the algorithm uses :term:`implicit differentiation` of the optimality + conditions (``implicit_diff`` not equal to ``None``). This has two + consequences: - The termination criterion used to stop Sinkhorn (cancellation of gradient of objective w.r.t. ``f_u`` and ``g_v``) is used to differentiate - ``f`` and ``g``, given a change in the inputs. These changes are computed - by solving a linear system. The arguments starting with - ``implicit_solver_*`` allow to define the linear solver that is used, and - to control for two types or regularization (we have observed that, - depending on the architecture, linear solves may require higher ridge - parameters to remain stable). The optimality conditions in Sinkhorn can be - analyzed as satisfying a ``z=z'`` condition, which are then + the :term:`dual Kantorovich potentials` ``f`` and ``g``, given a change in + the inputs. These changes are computed by solving a linear system. 
The + optimality conditions of the :term:`entropy-regularized optimal transport` + problem can be analyzed as satisfying a ``z=z'`` condition, which are then differentiated. It might be beneficial (e.g., as in :cite:`cuturi:20a`) to use a preconditioning function ``precondition_fun`` to differentiate instead ``h(z) = h(z')``. - The objective ``reg_ot_cost`` returned by Sinkhorn uses the so-called - envelope (or Danskin's) theorem. In that case, because it is assumed that - the gradients of the dual variables ``f_u`` and ``g_v`` w.r.t. dual - objective are zero (reflecting the fact that they are optimal), small - variations in ``f_u`` and ``g_v`` due to changes in inputs (such as - ``geom``, ``a`` and ``b``) are considered negligible. As a result, - ``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` when - evaluating the ``reg_ot_cost`` objective. Note that this approach is + :term:`envelope theorem` (a.k.a. Danskin's theorem). In that case, + because it is assumed that the gradients of the dual variables ``f_u`` and + ``g_v`` w.r.t. dual objective are zero (reflecting the fact that they are + optimal), small variations in ``f_u`` and ``g_v`` due to changes in inputs + (such as ``geom``, ``a`` and ``b``) are considered negligible. As a + result, ``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` + when evaluating the ``reg_ot_cost`` objective. Note that this approach is `invalid` when computing higher order derivatives. In that case the - ``use_danskin`` flag must be set to ``False``. + ``use_danskin`` flag must be set to :obj:`False`. An alternative yet more costly way to differentiate the outputs of the - Sinkhorn iterations is to use unrolling, i.e. reverse mode differentiation - of the Sinkhorn loop. This is possible because Sinkhorn iterations are - wrapped in a custom fixed point iteration loop, defined in + Sinkhorn iterations is to use :term:`unrolling`, i.e. reverse mode + differentiation of the Sinkhorn loop. This is possible because Sinkhorn + iterations are wrapped in a custom fixed point iteration loop, defined in ``fixed_point_loop``, rather than a standard while loop. This is to ensure the end result of this fixed point loop can also be differentiated, if needed, using standard JAX operations. To ensure differentiability, @@ -712,8 +600,8 @@ class Sinkhorn: ``inner_iterations`` at a time. Note: - * The Sinkhorn algorithm may not converge within the maximum number of - iterations for possibly several reasons: + * The :term:`Sinkhorn algorithm` may not converge within the maximum number + of iterations for possibly several reasons: 1. the regularizer (defined as ``epsilon`` in the geometry ``geom`` object) is too small. Consider either switching to ``lse_mode=True`` @@ -727,31 +615,34 @@ class Sinkhorn: 3. OOMs issues may arise when storing either cost or kernel matrices that are too large in ``geom``. In the case where, the ``geom`` geometry is a ``PointCloud``, some of these issues might be solved by setting the - ``online`` flag to ``True``. This will trigger a re-computation on the - fly of the cost/kernel matrix. + ``online`` flag to :obj:`True`. This will trigger a re-computation on + the fly of the cost/kernel matrix. * The weight vectors ``a`` and ``b`` can be passed on with coordinates that have zero weight. This is then handled by relying on simple arithmetic for ``inf`` values that will likely arise (due to :math:`\log 0` when - ``lse_mode`` is ``True``, or divisions by zero when ``lse_mode`` is - ``False``). 
Whenever that arithmetic is likely to produce ``NaN`` values - (due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we use - ``jnp.where`` conditional statements to carry ``inf`` rather than ``NaN`` - values. In the reverse mode differentiation, the inputs corresponding to - these 0 weights (a location `x`, or a row in the corresponding cost/kernel - matrix), and the weight itself will have ``NaN`` gradient values. This is - reflects that these gradients are undefined, since these points were not - considered in the optimization and have therefore no impact on the output. + ``lse_mode`` is :obj:`True`, or divisions by zero when ``lse_mode`` is + :obj:`False`). Whenever that arithmetic is likely to produce ``NaN`` + values (due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we + use ``jnp.where`` conditional statements to carry ``inf`` rather than + ``NaN`` values. In reverse mode differentiation, the inputs corresponding + to these 0 weights (a location `x`, or a row in the corresponding + cost/kernel matrix), and the weight itself will have ``NaN`` gradient + values. This reflects that these gradients are undefined, since these + points were not considered in the optimization and have therefore no + impact on the output. Args: - lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel + lse_mode: :obj:`True` for log-sum-exp computations, :obj:`False` for kernel multiplication. threshold: tolerance used to stop the Sinkhorn iterations. This is typically the deviation between a target marginal and the marginal of the - current primal solution when either or both tau_a and tau_b are 1.0 - (balanced or semi-balanced problem), or the relative change between two - successive solutions in the unbalanced case. - norm_error: power used to define p-norm of error for marginal/target. + current primal solution when either or both ``tau_a`` and ``tau_b`` are + :math:`1.0` (balanced or semi-balanced problem), or the relative change + between two successive solutions in the unbalanced case. + norm_error: power used to define the :math:`p`-norm used to quantify + the magnitude of the gradients. This criterion is used to terminate the + algorithm. inner_iterations: the Sinkhorn error is not recomputed at each iteration but every ``inner_iterations`` instead. min_iterations: the minimum number of Sinkhorn iterations carried @@ -760,26 +651,30 @@ class Sinkhorn: ``max_iterations`` is equal to ``min_iterations``, Sinkhorn iterations are run by default using a :func:`jax.lax.scan` loop rather than a custom, unroll-able :func:`jax.lax.while_loop` that monitors convergence. - In that case the error is not monitored and the ``converged`` - flag will return ``False`` as a consequence. - momentum: Momentum instance. - anderson: AndersonAcceleration instance. - implicit_diff: instance used to solve implicit differentiation. Unrolls - iterations if None. - parallel_dual_updates: updates potentials or scalings in parallel if True, - sequentially (in Gauss-Seidel fashion) if False. + In that case the error is only computed at the last iteration. + momentum: :class:`~ott.solvers.linear.acceleration.Momentum` instance. + anderson: :class:`~ott.solvers.linear.acceleration.AndersonAcceleration` + instance. + implicit_diff: + :class:`~ott.solvers.linear.implicit_differentiation.ImplicitDiff` + instance used to parameterize the linear solvers used in + :term:`implicit differentiation`. Tha algorithm uses :term:`unrolling` of + iterations if ``None``. 
+ parallel_dual_updates: updates potentials or scalings in parallel if + :obj:`True`, sequentially (in Gauss-Seidel fashion) if :obj:`False`. recenter_potentials: Whether to re-center the dual potentials. If the problem is balanced, the ``f`` potential is zero-centered for numerical stability. Otherwise, use the approach of :cite:`sejourne:22` to achieve faster convergence. Only used when ``lse_mode = True`` and ``tau_a < 1`` and ``tau_b < 1``. - use_danskin: when ``True``, it is assumed the entropy regularized cost - is evaluated using optimal potentials that are frozen, i.e. whose - gradients have been stopped. This is useful when carrying out first order - differentiation, and is only valid (as with ``implicit_differentiation``) - when the algorithm has converged with a low tolerance. - initializer: how to compute the initial potentials/scalings. This refers to - a few possible classes implemented following the template in + use_danskin: when :obj:`True`, it is assumed the + :term:`entropy-regularized optimal transport` cost + is evaluated using :term:`dual Kantorovich potentials` that are frozen, + i.e. whose gradients have been stopped. This is useful when carrying out + first order differentiation, and is only valid when the algorithm has + converged with a low tolerance. + initializer: method to compute the initial potentials/scalings. This refers + to a few possible classes implemented following the template in :class:`~ott.initializers.linear.SinkhornInitializer`. progress_fn: callback function which gets called during the Sinkhorn iterations, so the user can display the error at each iteration, @@ -1070,7 +965,7 @@ def output_from_state( The flag ``use_danskin`` controls whether that assumption is made. By default, that flag is set to the value of ``implicit_differentiation`` if not specified. If you wish to compute derivatives of order 2 and above, - set ``use_danskin`` to ``False``. + set ``use_danskin`` to :obj:`False`. Args: ot_prob: the transport problem. diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 3eb23a89b..142383d68 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -154,7 +154,7 @@ def update( # noqa: D102 @jax.tree_util.register_pytree_node_class class GromovWasserstein(was_solver.WassersteinSolver): - """Gromov-Wasserstein solver :cite:`peyre:16`. + """Entropic Gromov-Wasserstein solver :cite:`peyre:16`. .. seealso:: Low-rank Gromov-Wasserstein :cite:`scetbon:23` is implemented in diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index c35b01c4e..d96025b58 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -511,6 +511,7 @@ def loss_from_potential(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): loss_back = jax.jit( jax.grad(lambda a, x: loss_from_potential(a, x, False), argnums=arg) ) + g_back = loss_back(a, x) back_dif = jnp.sum(g_back * (delta_a if arg == 0 else delta_x))
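
As a usage-level complement to the docstring changes above and to the gradient
test modified in the last hunk, the following sketch (not part of the patch;
sizes, epsilon and the comparison tolerance are arbitrary) shows how the
``implicit_diff`` argument selects between :term:`implicit differentiation` and
:term:`unrolling` when differentiating the regularized OT cost with respect to
a point cloud, assuming the public ``ott`` modules referenced in this patch.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import implicit_differentiation as imp_diff
    from ott.solvers.linear import sinkhorn

    rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(rng_x, (13, 2))
    y = jax.random.normal(rng_y, (17, 2))

    def reg_ot_cost(x: jnp.ndarray, implicit: bool) -> jnp.ndarray:
      # Entropy-regularized OT cost between two uniform point clouds.
      geom = pointcloud.PointCloud(x, y, epsilon=0.1)
      prob = linear_problem.LinearProblem(geom)
      solver = sinkhorn.Sinkhorn(
          implicit_diff=imp_diff.ImplicitDiff() if implicit else None
      )
      return solver(prob).reg_ot_cost

    # Gradient w.r.t. the source points, via implicit differentiation of the
    # first-order conditions vs. unrolling of the Sinkhorn iterations; the two
    # should approximately agree once the solver has converged.
    grad_implicit = jax.grad(reg_ot_cost)(x, True)
    grad_unrolled = jax.grad(reg_ot_cost)(x, False)
    print(jnp.allclose(grad_implicit, grad_unrolled, atol=1e-2))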