-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use lineax to solve linear system in implicit diff #370
Merged
Changes from all commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
e72b113
use lineax to solve linear system in implicit diff
marcocuturi e243d69
doc
marcocuturi cb56be0
fix
marcocuturi 670fbf9
make lineax solvers optional, add a jax default
marcocuturi e11491b
pydoc
marcocuturi e170573
pydoc
marcocuturi eb438a4
pydoc
marcocuturi ac44e3c
pydoc
marcocuturi 587fdf8
pydoc
marcocuturi bf84695
pydoc
marcocuturi 200159c
selective tests
marcocuturi 7883450
fixing another test
marcocuturi 1ce4386
reintroduce ridge for jax solvers, to pass tests
marcocuturi 91fbdc8
fix again soft-sort using ridge
marcocuturi d2b4793
pydoc
marcocuturi 3abe03b
pydoc.
marcocuturi 978a5a5
lint
marcocuturi 82a181a
increase epsilon to ensure no_precond works.
marcocuturi 0f450df
readded backprop in test hessian + comments
marcocuturi 1c965cf
F401 in unused import.
marcocuturi da14527
change tolerance for kernel mode
marcocuturi 4f35978
remove finite diff / backprop test.
marcocuturi cafdbe0
adding lineax in __init__ for docs.
marcocuturi c79c760
adding back try import in test.
marcocuturi 60a3522
docs + test_back
marcocuturi 2f9d8a0
mod back
marcocuturi c399996
Update readthedocs.yml
michalk8 0ad3b55
Remove `contextlib`
michalk8 293bf93
Fix wrong file name
michalk8 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
version: 2 | ||
build: | ||
image: latest | ||
os: ubuntu-22.04 | ||
tools: | ||
python: '3.10' | ||
|
||
sphinx: | ||
builder: html | ||
configuration: docs/conf.py | ||
fail_on_warning: false | ||
|
||
python: | ||
version: 3.8 | ||
install: | ||
- method: pip | ||
path: . | ||
extra_requirements: | ||
- docs | ||
extra_requirements: [docs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright OTT-JAX | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Optional, TypeVar | ||
|
||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.tree_util as jtu | ||
import lineax as lx | ||
from jaxtyping import Array, Float, PyTree | ||
|
||
_T = TypeVar("_T") | ||
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef] | ||
|
||
__all__ = ["CustomTransposeLinearOperator"] | ||
|
||
|
||
class CustomTransposeLinearOperator(lx.FunctionLinearOperator): | ||
"""Implement a linear operator that can specify its transpose directly.""" | ||
fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]] | ||
fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]] | ||
input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field() | ||
input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field() | ||
tags: frozenset[object] | ||
|
||
def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()): | ||
super().__init__(fn, input_structure, tags) | ||
self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t) | ||
self.input_structure_t = input_structure_t | ||
|
||
def transpose(self): | ||
"""Provide custom transposition operator from function.""" | ||
return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t) | ||
|
||
|
||
def solve_lineax( | ||
lin: Callable, | ||
b: jnp.ndarray, | ||
lin_t: Optional[Callable] = None, | ||
symmetric: Optional[bool] = False, | ||
nonsym_solver: Optional[lx.AbstractLinearSolver] = None, | ||
**kwargs: Any | ||
) -> jnp.ndarray: | ||
"""Wrapper around lineax solvers. | ||
|
||
Args: | ||
lin: Linear operator | ||
b: vector. Returned `x` is such that `lin(x)=b` | ||
lin_t: Linear operator, corresponding to transpose of `lin`. | ||
symmetric: whether `lin` is symmetric. | ||
nonsym_solver: :class:`~lineax.AbstractLinearSolver` used when handling non | ||
symmetric cases. Note that :class:`~lineax.CG` is used by default in the | ||
symmetric case. | ||
kwargs: arguments passed to :mod:`~lineax.AbstractLinearSolver` linear | ||
solver. | ||
""" | ||
input_structure = jax.eval_shape(lambda: b) | ||
kwargs.setdefault("rtol", 1e-6) | ||
kwargs.setdefault("atol", 1e-6) | ||
if symmetric: | ||
solver = lx.CG(**kwargs) | ||
fn_operator = lx.FunctionLinearOperator( | ||
lin, input_structure, tags=lx.positive_semidefinite_tag | ||
) | ||
return lx.linear_solve(fn_operator, b, solver).value | ||
# In the nonsymmetric case, use NormalCG by default, but consider | ||
# user defined choice of alternative lx solver. | ||
solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver | ||
solver = solver_type(**kwargs) | ||
fn_operator = CustomTransposeLinearOperator( | ||
lin, lin_t, input_structure, input_structure | ||
) | ||
return lx.linear_solve(fn_operator, b, solver).value |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to discuss this, am ok with this solutition: alternative solution would be to remove these kwargs and require user to capture any additional keyword arguments using closure/partial.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think avoiding closure/partial is a bit preferable here.
But maybe not clean because IIUC there's no way to mark a Callable that takes optional arguments (...). Another option would be to pass a dictionary (last input = Any) and "fish" variables in there?