use lineax to solve linear system in implicit diff #370

Merged · 29 commits · Jun 20, 2023

Commits (29)
e72b113: use lineax to solve linear system in implicit diff (marcocuturi, Jun 13, 2023)
e243d69: doc (marcocuturi, Jun 13, 2023)
cb56be0: fix (marcocuturi, Jun 13, 2023)
670fbf9: make lineax solvers optional, add a jax default (marcocuturi, Jun 15, 2023)
e11491b: pydoc (marcocuturi, Jun 15, 2023)
e170573: pydoc (marcocuturi, Jun 15, 2023)
eb438a4: pydoc (marcocuturi, Jun 15, 2023)
ac44e3c: pydoc (marcocuturi, Jun 15, 2023)
587fdf8: pydoc (marcocuturi, Jun 15, 2023)
bf84695: pydoc (marcocuturi, Jun 15, 2023)
200159c: selective tests (marcocuturi, Jun 15, 2023)
7883450: fixing another test (marcocuturi, Jun 15, 2023)
1ce4386: reintroduce ridge for jax solvers, to pass tests (marcocuturi, Jun 16, 2023)
91fbdc8: fix again soft-sort using ridge (marcocuturi, Jun 16, 2023)
d2b4793: pydoc (marcocuturi, Jun 16, 2023)
3abe03b: pydoc. (marcocuturi, Jun 16, 2023)
978a5a5: lint (marcocuturi, Jun 16, 2023)
82a181a: increase epsilon to ensure no_precond works. (marcocuturi, Jun 16, 2023)
0f450df: readded backprop in test hessian + comments (marcocuturi, Jun 19, 2023)
1c965cf: F401 in unused import. (marcocuturi, Jun 19, 2023)
da14527: change tolerance for kernel mode (marcocuturi, Jun 19, 2023)
4f35978: remove finite diff / backprop test. (marcocuturi, Jun 19, 2023)
cafdbe0: adding lineax in __init__ for docs. (marcocuturi, Jun 19, 2023)
c79c760: adding back try import in test. (marcocuturi, Jun 19, 2023)
60a3522: docs + test_back (marcocuturi, Jun 19, 2023)
2f9d8a0: mod back (marcocuturi, Jun 19, 2023)
c399996: Update readthedocs.yml (michalk8, Jun 20, 2023)
0ad3b55: Remove `contextlib` (michalk8, Jun 20, 2023)
293bf93: Fix wrong file name (michalk8, Jun 20, 2023)
Files changed
13 changes: 8 additions & 5 deletions .readthedocs.yaml
@@ -1,12 +1,15 @@
 version: 2
 build:
-  image: latest
+  os: ubuntu-22.04
+  tools:
+    python: '3.10'

 sphinx:
   builder: html
   configuration: docs/conf.py
   fail_on_warning: false

 python:
-  version: 3.8
   install:
   - method: pip
     path: .
-    extra_requirements:
-    - docs
+    extra_requirements: [docs]
1 change: 1 addition & 0 deletions docs/conf.py
@@ -62,6 +62,7 @@
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"lineax": ("https://docs.kidger.site/lineax/", None),
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"scikit-sparse": ("https://scikit-sparse.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
2 changes: 2 additions & 0 deletions docs/solvers/linear.rst
@@ -48,3 +48,5 @@ Implicit Differentiation
:toctree: _autosummary

implicit_differentiation.ImplicitDiff
implicit_differentiation.solve_jax_cg
lineax_implicit.solve_lineax
190 changes: 125 additions & 65 deletions docs/tutorials/notebooks/Hessians.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"numpy>=1.18.4, !=1.23.0",
"flax>=0.5.2",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
]
keywords = [
"optimal transport",
@@ -73,6 +74,7 @@ test = [
"scikit-learn>=1.0",
# tslearn needs numba, which isn't supported for 3.11
"tslearn>=0.5; python_version < '3.11'",
"lineax; python_version >= '3.9'",
]
docs = [
"sphinx>=4.0",
1 change: 1 addition & 0 deletions src/ott/solvers/linear/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import (
acceleration,
continuous_barycenter,
163 changes: 95 additions & 68 deletions src/ott/solvers/linear/implicit_differentiation.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
@@ -23,6 +23,10 @@
if TYPE_CHECKING:
from ott.problems.linear import linear_problem

LinOp_t = Callable[[jnp.ndarray], jnp.ndarray]
Solver_t = Callable[[LinOp_t, jnp.ndarray, Optional[LinOp_t], bool],
jnp.ndarray]

__all__ = ["ImplicitDiff"]


@@ -31,16 +35,28 @@ class ImplicitDiff:
"""Implicit differentiation of Sinkhorn algorithm.

Args:
solver_fun: Callable, should return (solution, ...)
ridge_kernel: promotes zero-sum solutions. only used if tau_a = tau_b = 1.0
ridge_identity: handles rank deficient transport matrices (this happens
typically when rows/cols in cost/kernel matrices are collinear, or,
equivalently when two points from either measure are close).
solver: Callable to compute the solution to a linear problem. The callable
expects a linear function, a vector, optionally another linear function
that implements the transpose of that function, and a boolean flag to
specify symmetry. By default, this solver is one of the :class:`lineax.CG`
or :class:`lineax.NormalCG` solvers, if that package can be imported, as
described in :func:`~ott.solvers.linear.lineax_implicit.solve_lineax`.
The :mod:`jax` alternative is described in
:func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`.
Note that `lineax` solvers handle poorly conditioned problems better; these
typically arise when differentiating the solutions of balanced OT problems
(when ``tau_a == tau_b == 1.0``). Relying on
:func:`~ott.solvers.linear.implicit_differentiation.solve_jax_cg`
for such cases might require hand-tuning the ridge parameters
``ridge_kernel`` and ``ridge_identity`` described in its docstring.
These parameters can be passed using ``solver_kwargs`` below.
solver_kwargs: keyword arguments passed on to the solver.
symmetric: flag indicating whether the linear system solved when applying
the implicit function theorem is symmetric. This happens when
``tau_a == tau_b`` and either ``a == b`` or ``precondition_fun`` is the
identity. The flag is ``False`` by default, and is also tested against
``tau_a == tau_b``. It needs to be set manually by the user in the more
favorable case where the system is guaranteed to be symmetric.
precondition_fun: Function used to precondition, on both sides, the linear
system derived from first-order conditions of the regularized OT problem.
That linear system typically involves an equality between marginals (or
@@ -51,17 +67,18 @@
theorem differentiation.
"""

solver_fun: Callable[[jnp.ndarray, jnp.ndarray],
Tuple[jnp.ndarray, ...]] = jax.scipy.sparse.linalg.cg
ridge_kernel: float = 0.0
ridge_identity: float = 0.0
solver: Optional[Solver_t] = None
solver_kwargs: Optional[Dict[str, Any]] = None
Review discussion on `solver_kwargs`:

Collaborator: Just to discuss this, I am OK with this solution; an alternative
would be to remove these kwargs and require the user to capture any additional
keyword arguments using a closure/partial.

Contributor (author): I think avoiding closure/partial is a bit preferable
here. But maybe not clean, because IIUC there's no way to mark a Callable that
takes optional arguments (...). Another option would be to pass a dictionary
(last input = Any) and "fish" variables in there?
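For reference, a minimal sketch of the closure/partial alternative discussed
above (hypothetical usage; the ridge values are made up, and this is not the
design that was merged):

    import functools

    from ott.solvers.linear import implicit_differentiation as imp_diff

    # Bind extra keyword arguments into the solver itself, instead of
    # passing them through `solver_kwargs`:
    solver = functools.partial(
        imp_diff.solve_jax_cg, ridge_kernel=1e-6, ridge_identity=1e-6
    )
    impl = imp_diff.ImplicitDiff(solver=solver)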

symmetric: bool = False
precondition_fun: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None

def solve(
self, gr: Tuple[jnp.ndarray,
jnp.ndarray], ot_prob: "linear_problem.LinearProblem",
f: jnp.ndarray, g: jnp.ndarray, lse_mode: bool
self,
gr: Tuple[jnp.ndarray, jnp.ndarray],
ot_prob: "linear_problem.LinearProblem",
f: jnp.ndarray,
g: jnp.ndarray,
lse_mode: bool,
) -> jnp.ndarray:
r"""Apply minus inverse of [hessian ``reg_ot_cost`` w.r.t. ``f``, ``g``].

@@ -117,20 +134,9 @@ def solve(
instantiate the Schur complement of the first or of the second diagonal
block.

In either case, the Schur complement is rank deficient, with a 0 eigenvalue
for the vector of ones in the balanced case, which is why we add a ridge on
that subspace to enforce solutions have zero sum.

The Schur complement can also be rank deficient if two lines or columns of T
are collinear. This will typically happen it two rows or columns of the cost
or kernel matrix are numerically close. To avoid this, we add a more global
``ridge_identity * z`` regularizer to achieve better conditioning.

These linear systems are solved using the user defined ``solver_fun``,
which is set by default to ``cg``. When the system is symmetric (as detected
by the corresponding flag ``symmetric``), ``cg`` is applied directly. When
it is not, normal equations are used (i.e. the Schur complement is
multiplied by its transpose before solving the system).
These linear systems are solved using the user-defined ``solver``, which
defaults to :mod:`lineax` solvers when the package is available, and falls
back on :mod:`jax` otherwise.

Args:
gr: 2-tuple, (vector of size ``n``, vector of size ``m``).
@@ -142,17 +148,19 @@
Returns:
A tuple of two vectors, of the same size as ``gr``.
"""
solver = _get_solver() if self.solver is None else self.solver
solver_kwargs = {} if self.solver_kwargs is None else self.solver_kwargs
geom = ot_prob.geom
marginal_a, marginal_b, app_transport = (
ot_prob.get_transport_functions(lse_mode)
)

# Elementwise vmap apply of the derivative of precondition_fun; not
# vmapping here could be problematic.
if self.precondition_fun is None:
precond_fun = lambda x: geom.epsilon * jnp.log(x)
symmetric = False
else:
precond_fun = self.precondition_fun
symmetric = self.symmetric

derivative = jax.vmap(jax.grad(precond_fun))

n, m = geom.shape
@@ -164,7 +172,7 @@
f, g, z * derivative(marginal_a(f, g)), axis=0
) / geom.epsilon

if not self.symmetric:
if not symmetric:
vjp_fgt = lambda z: app_transport(
f, g, z, axis=0
) * derivative(marginal_b(f, g)) / geom.epsilon
@@ -184,52 +192,32 @@
ot_prob.b, g, ot_prob.tau_b, geom.epsilon, derivative
)
)

n, m = geom.shape
# Remove ridge on kernel space if problem is balanced.
ridge_kernel = jnp.where(ot_prob.is_balanced, self.ridge_kernel, 0.0)

# TODO(cuturi) consider materializing linear operator schur if size allows.
# Forks on using Schur complement of either A or D, depending on size.
if n > m: # if n is bigger, run m x m linear system.
inv_vjp_ff = lambda z: z / diag_hess_a
vjp_gg = lambda z: z * diag_hess_b
schur_ = lambda z: vjp_gg(z) - vjp_gf(inv_vjp_ff(vjp_fg(z)))
res = gr[1] - vjp_gf(inv_vjp_ff(gr[0]))

if self.symmetric:
schur = lambda z: (
schur_(z) + ridge_kernel * jnp.sum(z) + self.ridge_identity * z
)
else:
schur = lambda z: vjp_gg(z) - vjp_gf(inv_vjp_ff(vjp_fg(z)))
if not symmetric:
schur_t = lambda z: vjp_gg(z) - vjp_fgt(inv_vjp_ff(vjp_gft(z)))
res = schur_t(res)
schur = lambda z: (
schur_t(schur_(z)) + ridge_kernel * jnp.sum(z) + self.ridge_identity
* z
)

sch = self.solver_fun(schur, res)[0]
else:
schur_t = None
res = gr[1] - vjp_gf(inv_vjp_ff(gr[0]))
sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
vjp_gr_f = inv_vjp_ff(gr[0] - vjp_fg(sch))
vjp_gr_g = sch
else:
vjp_ff = lambda z: z * diag_hess_a
inv_vjp_gg = lambda z: z / diag_hess_b
schur_ = lambda z: vjp_ff(z) - vjp_fg(inv_vjp_gg(vjp_gf(z)))
res = gr[0] - vjp_fg(inv_vjp_gg(gr[1]))
schur = lambda z: vjp_ff(z) - vjp_fg(inv_vjp_gg(vjp_gf(z)))

if self.symmetric:
schur = lambda z: (
schur_(z) + ridge_kernel * jnp.sum(z) + self.ridge_identity * z
)
else:
if not symmetric:
schur_t = lambda z: vjp_ff(z) - vjp_gft(inv_vjp_gg(vjp_fgt(z)))
res = schur_t(res)
schur = lambda z: (
schur_t(schur_(z)) + ridge_kernel * jnp.sum(z) + self.ridge_identity
* z
)

sch = self.solver_fun(schur, res)[0]
else:
schur_t = None
res = gr[0] - vjp_fg(inv_vjp_gg(gr[1]))
sch = solver(schur, res, schur_t, symmetric, **solver_kwargs)
vjp_gr_g = inv_vjp_gg(gr[1] - vjp_gf(sch))
vjp_gr_f = sch

@@ -295,3 +283,42 @@ def gradient(

def replace(self, **kwargs: Any) -> "ImplicitDiff": # noqa: D102
return dataclasses.replace(self, **kwargs)
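
# A possible usage sketch (illustrative, not part of this diff): plugging a
# configured `ImplicitDiff` into Sinkhorn, assuming Sinkhorn's
# ``implicit_diff`` argument and the lineax solver path (jax's ``cg`` uses
# ``tol``/``atol`` instead of ``rtol``):
#
#   from ott.solvers.linear import implicit_differentiation, sinkhorn
#
#   impl = implicit_differentiation.ImplicitDiff(
#       solver_kwargs={"rtol": 1e-8, "atol": 1e-8}
#   )
#   solver = sinkhorn.Sinkhorn(implicit_diff=impl)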


def solve_jax_cg(
lin: LinOp_t,
b: jnp.ndarray,
lin_t: Optional[LinOp_t] = None,
symmetric: bool = False,
ridge_identity: float = 0.0,
ridge_kernel: float = 0.0,
**kwargs: Any
) -> jnp.ndarray:
"""Wrapper around JAX native linear solvers.

Args:
lin: Linear operator.
b: vector. Returned `x` is such that `lin(x)=b`
lin_t: Linear operator, corresponding to transpose of `lin`.
symmetric: whether `lin` is symmetric.
ridge_kernel: promotes zero-sum solutions. Only used if
`tau_a == tau_b == 1.0`.
ridge_identity: handles rank-deficient transport matrices (this typically
happens when rows/cols in cost/kernel matrices are collinear, or,
equivalently, when two points from either measure are close).
kwargs: arguments passed to :func:`~jax.scipy.sparse.linalg.cg`
"""
op = lin if symmetric else lambda x: lin_t(lin(x))
if ridge_kernel > 0.0 or ridge_identity > 0.0:
lin_reg = lambda x: op(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x
else:
lin_reg = op
return jax.scipy.sparse.linalg.cg(lin_reg, b, **kwargs)[0]


def _get_solver() -> Solver_t:
"""Get lineax solver when possible, default to jax.scipy else."""
try:
from ott.solvers.linear import lineax_implicit
return lineax_implicit.solve_lineax
except ImportError:
return solve_jax_cg
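
For illustration, a minimal standalone sketch of `solve_jax_cg` (hypothetical
values; the operator is diagonal, so the expected solution is `b / d`):

    import jax.numpy as jnp

    from ott.solvers.linear import implicit_differentiation as imp_diff

    d = jnp.array([1.0, 2.0, 4.0])
    lin = lambda x: d * x  # symmetric, positive-definite diagonal operator
    b = jnp.ones(3)
    x = imp_diff.solve_jax_cg(lin, b, lin_t=lin, symmetric=True)
    # x is approximately [1.0, 0.5, 0.25]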
84 changes: 84 additions & 0 deletions src/ott/solvers/linear/lineax_implicit.py
@@ -0,0 +1,84 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
from jaxtyping import Array, Float, PyTree

_T = TypeVar("_T")
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]

__all__ = ["CustomTransposeLinearOperator"]


class CustomTransposeLinearOperator(lx.FunctionLinearOperator):
"""Implement a linear operator that can specify its transpose directly."""
fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
tags: frozenset[object]

def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()):
super().__init__(fn, input_structure, tags)
self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t)
self.input_structure_t = input_structure_t

def transpose(self):
"""Provide custom transposition operator from function."""
return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t)
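
# Illustrative sketch (not part of this file): wrapping a non-square map with
# an explicitly provided transpose. Shapes and values are made up.
#
#   import jax
#   import jax.numpy as jnp
#
#   A = jnp.ones((3, 2))
#   op = CustomTransposeLinearOperator(
#       fn=lambda x: A @ x,
#       fn_t=lambda y: A.T @ y,
#       input_structure=jax.eval_shape(lambda: jnp.zeros((2,))),
#       input_structure_t=jax.eval_shape(lambda: jnp.zeros((3,))),
#   )
#   op_t = op.transpose()  # FunctionLinearOperator wrapping fn_t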


def solve_lineax(
lin: Callable,
b: jnp.ndarray,
lin_t: Optional[Callable] = None,
symmetric: Optional[bool] = False,
nonsym_solver: Optional[lx.AbstractLinearSolver] = None,
**kwargs: Any
) -> jnp.ndarray:
"""Wrapper around lineax solvers.

Args:
lin: Linear operator
b: vector. Returned `x` is such that `lin(x)=b`
lin_t: Linear operator, corresponding to transpose of `lin`.
symmetric: whether `lin` is symmetric.
nonsym_solver: :class:`~lineax.AbstractLinearSolver` used when handling
non-symmetric cases. Note that :class:`~lineax.CG` is used by default in
the symmetric case.
kwargs: arguments passed to the :class:`~lineax.AbstractLinearSolver`
linear solver.
"""
input_structure = jax.eval_shape(lambda: b)
kwargs.setdefault("rtol", 1e-6)
kwargs.setdefault("atol", 1e-6)
if symmetric:
solver = lx.CG(**kwargs)
fn_operator = lx.FunctionLinearOperator(
lin, input_structure, tags=lx.positive_semidefinite_tag
)
return lx.linear_solve(fn_operator, b, solver).value
# In the non-symmetric case, use NormalCG by default, unless the user
# specifies an alternative lineax solver.
solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver
solver = solver_type(**kwargs)
fn_operator = CustomTransposeLinearOperator(
lin, lin_t, input_structure, input_structure
)
return lx.linear_solve(fn_operator, b, solver).value
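
A matching usage sketch for `solve_lineax` (illustrative; mirrors the
`solve_jax_cg` example above, exercising both code paths):

    import jax.numpy as jnp

    from ott.solvers.linear import lineax_implicit

    d = jnp.array([1.0, 2.0, 4.0])
    lin = lambda x: d * x
    b = jnp.ones(3)
    # Symmetric positive-definite case: handled by lineax.CG.
    x_sym = lineax_implicit.solve_lineax(lin, b, symmetric=True)
    # Non-symmetric path: NormalCG, with the transpose given explicitly.
    x_nonsym = lineax_implicit.solve_lineax(lin, b, lin_t=lin)
    # Both are approximately [1.0, 0.5, 0.25].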