use lineax to solve linear system in implicit diff (#370)

* use lineax to solve linear system in implicit diff

* doc

* fix

* make lineax solvers optional, add a jax default

* pydoc

* pydoc

* pydoc

* pydoc

* pydoc

* pydoc

* selective tests

* fixing another test

* reintroduce ridge for jax solvers, to pass tests

* fix again soft-sort using ridge

* pydoc

* pydoc.

* lint

* increase epsilon to ensure no_precond works.

* readded backprop in test hessian + comments

* F401 in unused import.

* change tolerance for kernel mode

* remove finite diff / backprop test.

* adding lineax in __init__ for docs.

* adding back try import in test.

* docs + test_back

* mod back

* Update readthedocs.yml

* Remove `contextlib`

* Fix wrong file name

---------

Co-authored-by: Michal Klein <[email protected]>
marcocuturi and michalk8 authored Jun 20, 2023
1 parent 31df701 commit 428316c
Showing 10 changed files with 405 additions and 195 deletions.
13 changes: 8 additions & 5 deletions .readthedocs.yaml
@@ -1,12 +1,15 @@
 version: 2
 build:
-  image: latest
+  os: ubuntu-22.04
+  tools:
+    python: '3.10'
+
 sphinx:
   builder: html
   configuration: docs/conf.py
   fail_on_warning: false
+
 python:
-  version: 3.8
   install:
   - method: pip
     path: .
-    extra_requirements:
-    - docs
+    extra_requirements: [docs]
1 change: 1 addition & 0 deletions docs/conf.py
@@ -62,6 +62,7 @@
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"lineax": ("https://docs.kidger.site/lineax/", None),
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
2 changes: 2 additions & 0 deletions docs/solvers/linear.rst
@@ -48,3 +48,5 @@ Implicit Differentiation
    :toctree: _autosummary

    implicit_differentiation.ImplicitDiff
+   implicit_differentiation.solve_jax_cg
+   lineax_implicit.solve_lineax
190 changes: 125 additions & 65 deletions docs/tutorials/notebooks/Hessians.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"numpy>=1.18.4, !=1.23.0",
"flax>=0.5.2",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
]
keywords = [
"optimal transport",
@@ -73,6 +74,7 @@ test = [
"scikit-learn>=1.0",
# tslearn needs numba, which isn't supported for 3.11
"tslearn>=0.5; python_version < '3.11'",
"lineax; python_version >= '3.9'",
]
docs = [
"sphinx>=4.0",
Expand Down
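Since lineax is only pulled in for Python >= 3.9 (see the environment markers above), code that touches it needs an import guard; the commit log above mentions "adding back try import in test" for exactly this reason. A minimal sketch of such a guard, with an illustrative ``HAS_LINEAX`` flag name that is not part of this commit (the commit's own fallback lives in ``_get_solver`` further below):

try:
  import lineax  # noqa: F401  # optional dependency, py >= 3.9 only
  HAS_LINEAX = True
except ImportError:  # lineax absent: fall back to jax-only code paths
  HAS_LINEAX = False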
1 change: 1 addition & 0 deletions src/ott/solvers/linear/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from . import (
     acceleration,
     continuous_barycenter,
163 changes: 95 additions & 68 deletions src/ott/solvers/linear/implicit_differentiation.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import dataclasses
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

 import jax
 import jax.numpy as jnp
@@ -23,6 +23,10 @@
 if TYPE_CHECKING:
   from ott.problems.linear import linear_problem

+LinOp_t = Callable[[jnp.ndarray], jnp.ndarray]
+Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool],
+                    jnp.ndarray]
+
 __all__ = ["ImplicitDiff"]


@@ -31,16 +35,28 @@ class ImplicitDiff:
"""Implicit differentiation of Sinkhorn algorithm.
Args:
solver_fun: Callable, should return (solution, ...)
ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0
ridge_identity: handles rank deficient transport matrices (this happens
typically when rows/cols in cost/kernel matrices are collinear, or,
equivalently when two points from either measure are close).
solver: Callable to compute the solution to a linear problem. The callable
expects a linear function, a vector, optionally another linear function
that implements the transpose of that function, and a boolean flag to
specify symmetry. This solver is by default one of :class:`lineax.CG` or
:class:`lineax.NormalCG` solvers, if the package can be imported, as
described in :func:`~ott.solvers.linear.lineax_implicit.solve_lineax`.
The :mod:`jax` alternative is described in
:func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`.
Note that `lineax` solvers handle better poorly conditioned problems,
which arise typically when differentiating the solutions of balanced OT
problems (when ``tau_a==tau_b==1.0``). Relying on
:func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`
for such cases might require hand-tuning ridge parameters,
in particular ``ridge_kernel`` and ``ridge_identity`` as described in its
doc. These parameters can be passed using ``solver_kwargs`` below.
solver_kwargs: keyword arguments passed on to the solver.
symmetric: flag used to figure out whether the linear system solved in the
implicit function theorem is symmetric or not. This happens when either
``a == b`` or the precondition_fun is the identity. False by default, and,
at the moment, needs to be set manually by the user in the more favorable
case where the system is guaranteed to be symmetric.
implicit function theorem is symmetric or not. This happens when
``tau_a==tau_b``, and when ``a == b``, or the precondition_fun
is the identity. The flag is False by default, and is also tested against
``tau_a==tau_b``. It needs to be set manually by the user in the more
favorable case where the system is guaranteed to be symmetric.
precondition_fun: Function used to precondition, on both sides, the linear
system derived from first-order conditions of the regularized OT problem.
That linear system typically involves an equality between marginals (or
@@ -51,17 +67,18 @@ class ImplicitDiff:
     theorem differentiation.
   """

-  solver_fun: Callable[[jnp.ndarray, jnp.ndarray],
-                       Tuple[jnp.ndarray, ...]] = jax.scipy.sparse.linalg.cg
-  ridge_kernel: float = 0.0
-  ridge_identity: float = 0.0
+  solver: Optional[Solver_t] = None
+  solver_kwargs: Optional[Dict[str, Any]] = None
   symmetric: bool = False
   precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

   def solve(
-      self, gr: Tuple[jnp.ndarray,
-                      jnp.ndarray], ot_prob: "linear_problem.LinearProblem",
-      f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool
+      self,
+      gr: Tuple[jnp.ndarray, jnp.ndarray],
+      ot_prob: "linear_problem.LinearProblem",
+      f: jnp.ndarray,
+      g: jnp.ndarray,
+      lse_mode: bool,
   ) -> jnp.ndarray:
     r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``].
@@ -117,20 +134,9 @@ def solve(
     instantiate the Schur complement of the first or of the second diagonal
     block.

-    In either case, the Schur complement is rank deficient, with a 0 eigenvalue
-    for the vector of ones in the balanced case, which is why we add a ridge on
-    that subspace to enforce that solutions have zero sum.
-
-    The Schur complement can also be rank deficient if two rows or columns of T
-    are collinear. This will typically happen if two rows or columns of the cost
-    or kernel matrix are numerically close. To avoid this, we add a more global
-    ``ridge_identity * z`` regularizer to achieve better conditioning.
-
-    These linear systems are solved using the user-defined ``solver_fun``,
-    which is set by default to ``cg``. When the system is symmetric (as detected
-    by the corresponding flag ``symmetric``), ``cg`` is applied directly. When
-    it is not, normal equations are used (i.e. the Schur complement is
-    multiplied by its transpose before solving the system).
+    These linear systems are solved using the user-defined ``solver``, relying
+    by default on :mod:`lineax` solvers when available, or falling back on
+    :mod:`jax` otherwise.

     Args:
       gr: 2-tuple, (vector of size ``n``, vector of size ``m``).
@@ -142,17 +148,19 @@ def solve(
     Returns:
       A tuple of two vectors, of the same size as ``gr``.
     """
+    solver = _get_solver() if self.solver is None else self.solver
+    solver_kwargs = {} if self.solver_kwargs is None else self.solver_kwargs
     geom = ot_prob.geom
     marginal_a, marginal_b, app_transport = (
         ot_prob.get_transport_functions(lse_mode)
     )

     # elementwise vmap apply of derivative of precondition_fun. No vmapping
     # can be problematic here.
     if self.precondition_fun is None:
       precond_fun = lambda x: geom.epsilon * jnp.log(x)
+      symmetric = False
     else:
       precond_fun = self.precondition_fun
+      symmetric = self.symmetric

     derivative = jax.vmap(jax.grad(precond_fun))

     n, m = geom.shape
@@ -164,7 +172,7 @@ def solve(
         f, g, z * derivative(marginal_a(f, g)), axis=0
     ) / geom.epsilon

-    if not self.symmetric:
+    if not symmetric:
       vjp_fgt = lambda z: app_transport(
           f, g, z, axis=0
       ) * derivative(marginal_b(f, g)) / geom.epsilon
@@ -184,52 +192,32 @@ def solve(
             ot_prob.b, g, ot_prob.tau_b, geom.epsilon, derivative
         )
     )

-    n, m = geom.shape
-    # Remove ridge on kernel space if problem is balanced.
-    ridge_kernel = jnp.where(ot_prob.is_balanced, self.ridge_kernel, 0.0)
-
     # TODO(cuturi) consider materializing linear operator schur if size allows.
     # Forks on using Schur complement of either A or D, depending on size.
     if n > m:  # if n is bigger, run m x m linear system.
       inv_vjp_ff = lambda z: z / diag_hess_a
       vjp_gg = lambda z: z * diag_hess_b
-      schur_ = lambda z: vjp_gg(z) - vjp_gf(inv_vjp_ff(vjp_fg(z)))
-      res = gr[1] - vjp_gf(inv_vjp_ff(gr[0]))
-
-      if self.symmetric:
-        schur = lambda z: (
-            schur_(z) + ridge_kernel * jnp.sum(z) + self.ridge_identity * z
-        )
-      else:
-        schur_t = lambda z: vjp_gg(z) - vjp_fgt(inv_vjp_ff(vjp_gft(z)))
-        res = schur_t(res)
-        schur = lambda z: (
-            schur_t(schur_(z)) + ridge_kernel * jnp.sum(z) + self.ridge_identity
-            * z
-        )
-
-      sch = self.solver_fun(schur, res)[0]
+      schur = lambda z: vjp_gg(z) - vjp_gf(inv_vjp_ff(vjp_fg(z)))
+      if not symmetric:
+        schur_t = lambda z: vjp_gg(z) - vjp_fgt(inv_vjp_ff(vjp_gft(z)))
+      else:
+        schur_t = None
+      res = gr[1] - vjp_gf(inv_vjp_ff(gr[0]))
+      sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
       vjp_gr_f = inv_vjp_ff(gr[0] - vjp_fg(sch))
       vjp_gr_g = sch
     else:
       vjp_ff = lambda z: z * diag_hess_a
       inv_vjp_gg = lambda z: z / diag_hess_b
-      schur_ = lambda z: vjp_ff(z) - vjp_fg(inv_vjp_gg(vjp_gf(z)))
-      res = gr[0] - vjp_fg(inv_vjp_gg(gr[1]))
-
-      if self.symmetric:
-        schur = lambda z: (
-            schur_(z) + ridge_kernel * jnp.sum(z) + self.ridge_identity * z
-        )
-      else:
-        schur_t = lambda z: vjp_ff(z) - vjp_gft(inv_vjp_gg(vjp_fgt(z)))
-        res = schur_t(res)
-        schur = lambda z: (
-            schur_t(schur_(z)) + ridge_kernel * jnp.sum(z) + self.ridge_identity
-            * z
-        )
-
-      sch = self.solver_fun(schur, res)[0]
+      schur = lambda z: vjp_ff(z) - vjp_fg(inv_vjp_gg(vjp_gf(z)))
+      if not symmetric:
+        schur_t = lambda z: vjp_ff(z) - vjp_gft(inv_vjp_gg(vjp_fgt(z)))
+      else:
+        schur_t = None
+      res = gr[0] - vjp_fg(inv_vjp_gg(gr[1]))
+      sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
       vjp_gr_g = inv_vjp_gg(gr[1] - vjp_gf(sch))
       vjp_gr_f = sch

@@ -295,3 +283,42 @@ def gradient(

   def replace(self, **kwargs: Any) -> "ImplicitDiff":  # noqa: D102
     return dataclasses.replace(self, **kwargs)
+
+
+def solve_jax_cg(
+    lin: LinOp_t,
+    b: jnp.ndarray,
+    lin_t: Optional[LinOp_t] = None,
+    symmetric: bool = False,
+    ridge_identity: float = 0.0,
+    ridge_kernel: float = 0.0,
+    **kwargs: Any
+) -> jnp.ndarray:
+  """Wrapper around JAX native linear solvers.
+
+  Args:
+    lin: Linear operator.
+    b: vector. Returned ``x`` is such that ``lin(x) = b``.
+    lin_t: Linear operator, corresponding to the transpose of ``lin``.
+    symmetric: whether ``lin`` is symmetric.
+    ridge_kernel: promotes zero-sum solutions. Only used if
+      ``tau_a == tau_b == 1.0``.
+    ridge_identity: handles rank-deficient transport matrices (this happens
+      typically when rows/cols in cost/kernel matrices are collinear, or,
+      equivalently, when two points from either measure are close).
+    kwargs: arguments passed to :func:`~jax.scipy.sparse.linalg.cg`.
+  """
+  op = lin if symmetric else lambda x: lin_t(lin(x))
+  if ridge_kernel > 0.0 or ridge_identity > 0.0:
+    lin_reg = lambda x: op(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x
+  else:
+    lin_reg = op
+  return jax.scipy.sparse.linalg.cg(lin_reg, b, **kwargs)[0]
+
+
+def _get_solver() -> Solver_t:
+  """Return a lineax solver when possible, otherwise fall back to jax.scipy."""
+  try:
+    from ott.solvers.linear import lineax_implicit
+    return lineax_implicit.solve_lineax
+  except ImportError:
+    return solve_jax_cg
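For orientation, a minimal usage sketch (not part of the diff; the Sinkhorn wiring is assumed from the surrounding API): the new ``solver`` / ``solver_kwargs`` fields let a user opt out of the lineax default and force the jax.scipy CG fallback with hand-tuned ridge parameters, as the class docstring recommends for balanced problems:

from ott.solvers.linear import implicit_differentiation as imp_diff
from ott.solvers.linear import sinkhorn

# Force the jax.scipy CG fallback; ridge values are illustrative only.
impl = imp_diff.ImplicitDiff(
    solver=imp_diff.solve_jax_cg,
    solver_kwargs={"ridge_kernel": 1e-6, "ridge_identity": 1e-6},
)
solver = sinkhorn.Sinkhorn(implicit_diff=impl)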
84 changes: 84 additions & 0 deletions src/ott/solvers/linear/lineax_implicit.py
@@ -0,0 +1,84 @@
+# Copyright OTT-JAX
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Optional, TypeVar
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import jax.tree_util as jtu
+import lineax as lx
+from jaxtyping import Array, Float, PyTree
+
+_T = TypeVar("_T")
+_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]
+
+__all__ = ["CustomTransposeLinearOperator"]
+
+
+class CustomTransposeLinearOperator(lx.FunctionLinearOperator):
+  """Implement a linear operator that can specify its transpose directly."""
+  fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
+  fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
+  input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+  input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+  tags: frozenset[object]
+
+  def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()):
+    super().__init__(fn, input_structure, tags)
+    self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t)
+    self.input_structure_t = input_structure_t
+
+  def transpose(self):
+    """Provide custom transposition operator from function."""
+    return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t)
+
+
+def solve_lineax(
+    lin: Callable,
+    b: jnp.ndarray,
+    lin_t: Optional[Callable] = None,
+    symmetric: Optional[bool] = False,
+    nonsym_solver: Optional[lx.AbstractLinearSolver] = None,
+    **kwargs: Any
+) -> jnp.ndarray:
+  """Wrapper around lineax solvers.
+
+  Args:
+    lin: Linear operator.
+    b: vector. Returned ``x`` is such that ``lin(x) = b``.
+    lin_t: Linear operator, corresponding to the transpose of ``lin``.
+    symmetric: whether ``lin`` is symmetric.
+    nonsym_solver: :class:`~lineax.AbstractLinearSolver` used when handling
+      non-symmetric cases. Note that :class:`~lineax.CG` is used by default in
+      the symmetric case.
+    kwargs: arguments passed to the :class:`~lineax.AbstractLinearSolver`
+      linear solver.
+  """
+  input_structure = jax.eval_shape(lambda: b)
+  kwargs.setdefault("rtol", 1e-6)
+  kwargs.setdefault("atol", 1e-6)
+  if symmetric:
+    solver = lx.CG(**kwargs)
+    fn_operator = lx.FunctionLinearOperator(
+        lin, input_structure, tags=lx.positive_semidefinite_tag
+    )
+    return lx.linear_solve(fn_operator, b, solver).value
+  # In the non-symmetric case, use NormalCG by default, but consider
+  # a user-defined choice of alternative lineax solver.
+  solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver
+  solver = solver_type(**kwargs)
+  fn_operator = CustomTransposeLinearOperator(
+      lin, lin_t, input_structure, input_structure
+  )
+  return lx.linear_solve(fn_operator, b, solver).value
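To make the new solver concrete, a small self-contained sketch of calling solve_lineax directly on a toy system (values illustrative, not from this commit; requires lineax and equinox to be installed):

import jax.numpy as jnp

from ott.solvers.linear.lineax_implicit import solve_lineax

A = jnp.array([[3.0, 1.0], [0.0, 2.0]])  # non-symmetric operator
b = jnp.array([1.0, 4.0])

# Non-symmetric path: NormalCG, built from the operator and its transpose.
x = solve_lineax(lambda v: A @ v, b, lin_t=lambda v: A.T @ v, symmetric=False)

# Symmetric positive-definite path: plain CG, no transpose needed.
S = A @ A.T
y = solve_lineax(lambda v: S @ v, b, symmetric=True)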