-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use lineax to solve linear system in implicit diff (#370)
* use lineax to solve linear system in implicit diff * doc * fix * make lineax solvers optional, add a jax default * pydoc * pydoc * pydoc * pydoc * pydoc * pydoc * selective tests * fixing another test * reintroduce ridge for jax solvers, to pass tests * fix again soft-sort using ridge * pydoc * pydoc. * lint * increase epsilon to ensure no_precond works. * readded backprop in test hessian + comments * F401 in unused import. * change tolerance for kernel mode * remove finite diff / backprop test. * adding lineax in __init__ for docs. * adding back try import in test. * docs + test_back * mod back * Update readthedocs.yml * Remove `contextlib` * Fix wrong file name --------- Co-authored-by: Michal Klein <[email protected]>
- Loading branch information
1 parent
31df701
commit 428316c
Showing
10 changed files
with
405 additions
and
195 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
version: 2 | ||
build: | ||
image: latest | ||
os: ubuntu-22.04 | ||
tools: | ||
python: '3.10' | ||
|
||
sphinx: | ||
builder: html | ||
configuration: docs/conf.py | ||
fail_on_warning: false | ||
|
||
python: | ||
version: 3.8 | ||
install: | ||
- method: pip | ||
path: . | ||
extra_requirements: | ||
- docs | ||
extra_requirements: [docs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Optional, TypeVar | ||
|
||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.tree_util as jtu | ||
import lineax as lx | ||
from jaxtyping import Array, Float, PyTree | ||
|
||
_T = TypeVar("_T") | ||
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef] | ||
|
||
__all__ = ["CustomTransposeLinearOperator"] | ||
|
||
|
||
class CustomTransposeLinearOperator(lx.FunctionLinearOperator): | ||
"""Implement a linear operator that can specify its transpose directly.""" | ||
fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]] | ||
fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]] | ||
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field() | ||
input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field() | ||
tags: frozenset[object] | ||
|
||
def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()): | ||
super().__init__(fn, input_structure, tags) | ||
self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t) | ||
self.input_structure_t = input_structure_t | ||
|
||
def transpose(self): | ||
"""Provide custom transposition operator from function.""" | ||
return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t) | ||
|
||
|
||
def solve_lineax( | ||
lin: Callable, | ||
b: jnp.ndarray, | ||
lin_t: Optional[Callable] = None, | ||
symmetric: Optional[bool] = False, | ||
nonsym_solver: Optional[lx.AbstractLinearSolver] = None, | ||
**kwargs: Any | ||
) -> jnp.ndarray: | ||
"""Wrapper around lineax solvers. | ||
Args: | ||
lin: Linear operator | ||
b: vector. Returned `x` is such that `lin(x)=b` | ||
lin_t: Linear operator, corresponding to transpose of `lin`. | ||
symmetric: whether `lin` is symmetric. | ||
nonsym_solver: :class:`~lineax.AbstractLinearSolver` used when handling non | ||
symmetric cases. Note that :class:`~lineax.CG` is used by default in the | ||
symmetric case. | ||
kwargs: arguments passed to :mod:`~lineax.AbstractLinearSolver` linear | ||
solver. | ||
""" | ||
input_structure = jax.eval_shape(lambda: b) | ||
kwargs.setdefault("rtol", 1e-6) | ||
kwargs.setdefault("atol", 1e-6) | ||
if symmetric: | ||
solver = lx.CG(**kwargs) | ||
fn_operator = lx.FunctionLinearOperator( | ||
lin, input_structure, tags=lx.positive_semidefinite_tag | ||
) | ||
return lx.linear_solve(fn_operator, b, solver).value | ||
# In the nonsymmetric case, use NormalCG by default, but consider | ||
# user defined choice of alternative lx solver. | ||
solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver | ||
solver = solver_type(**kwargs) | ||
fn_operator = CustomTransposeLinearOperator( | ||
lin, lin_t, input_structure, input_structure | ||
) | ||
return lx.linear_solve(fn_operator, b, solver).value |
Oops, something went wrong.