Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use lineax to solve linear system in implicit diff #370

Merged
merged 29 commits into from
Jun 20, 2023
Merged

use lineax to solve linear system in implicit diff #370

merged 29 commits into from
Jun 20, 2023

Conversation

marcocuturi
Copy link
Contributor

@marcocuturi marcocuturi commented Jun 13, 2023

  • use lineax CG (symmetric case) or NormalCG (not symmetric case) to handle linear solves, when possible (at this moment this hinges mostly on whether is running python >=3.9). Default back to JAX's native scipy.sparse.linalg.cg (or normal CG in unsymmetric case) when lineax cannot be imported.
  • ridge parameters in linear solve can still be used for JAX solvers (ridge_kernel and ridge_identity) but are not handled for lineax solvers.
  • fix hessians notebook:
    • remove reference to identity preconditioning (everything uses default parameters now)
    • add check of eigenvalues of hessian and numerical accuracy of hessian approx.
  • some tests are now only tested with lineax.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@michalk8
Copy link
Collaborator

@marcocuturi the failing tests can be remedied by adding an importorskip, e.g., inside the test functions that use lineax, I'd do _ = pytest.importorskip("lineax") # some comment explaining while this is needed.

@codecov-commenter
Copy link

codecov-commenter commented Jun 15, 2023

Codecov Report

Merging #370 (293bf93) into main (31df701) will decrease coverage by 0.71%.
The diff coverage is 37.87%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #370      +/-   ##
==========================================
- Coverage   88.43%   87.73%   -0.71%     
==========================================
  Files          51       52       +1     
  Lines        5605     5650      +45     
  Branches      836      573     -263     
==========================================
  Hits         4957     4957              
- Misses        529      568      +39     
- Partials      119      125       +6     
Impacted Files Coverage Δ
src/ott/solvers/linear/lineax_implicit.py 5.88% <5.88%> (ø)
src/ott/solvers/linear/implicit_differentiation.py 82.02% <71.87%> (-14.09%) ⬇️

... and 1 file with indirect coverage changes

@marcocuturi
Copy link
Contributor Author

@marcocuturi the failing tests can be remedied by adding an importorskip, e.g., inside the test functions that use lineax, I'd do _ = pytest.importorskip("lineax") # some comment explaining while this is needed.

I went for something in the middle, sometimes I do use the ridge parameters so that JAX solvers run.

@marcocuturi marcocuturi requested a review from michalk8 June 16, 2023 09:06
src/ott/solvers/linear/implicit_differentiation.py Outdated Show resolved Hide resolved
src/ott/solvers/linear/implicit_differentiation.py Outdated Show resolved Hide resolved
src/ott/solvers/linear/implicit_differentiation.py Outdated Show resolved Hide resolved
specify symmetry. This solver is by default one of `lineax`'s `CG` or
`NormalCG` solvers, if the package can be imported, as described in
:func:`~ott.solvers.linear.lineax_implicit.solve_lineax`.
The pure `JAX` alternative is described in
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:mod:jax

ridge_kernel: float = 0.0
ridge_identity: float = 0.0
solver: Optional[Solver_t] = None
solver_kwargs: Optional[Dict[str, Any]] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to discuss this, am ok with this solutition: alternative solution would be to remove these kwargs and require user to capture any additional keyword arguments using closure/partial.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think avoiding closure/partial is a bit preferable here.

But maybe not clean because IIUC there's no way to mark a Callable that takes optional arguments (...). Another option would be to pass a dictionary (last input = Any) and "fish" variables in there?

src/ott/solvers/linear/lineax_implicit.py Outdated Show resolved Hide resolved
src/ott/solvers/linear/lineax_implicit.py Outdated Show resolved Hide resolved
kwargs.setdefault("rtol", 1e-6)
kwargs.setdefault("atol", 1e-6)
# Ridge parameters passed to JAX solvers are ignored in lineax.
_ = kwargs.pop("ridge_identity", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these need to be popped? They are no default arguments in kwargs; if it's just because of the tests, would remove the kwargs.pop() here and adjust the tests instead.

Copy link
Contributor Author

@marcocuturi marcocuturi Jun 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, this was because of tests, because this requires now to try importing lineax at test time (rather than when running the implicit solver), but will do.


Args:
lin: Linear operator
b: vector such that sought `x` is such that `lin(x)=b`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too much "such that", would rephrase as:
Vector :math:b such that :math:A x = b.

And update the docs of the other function as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added something a bit less mathematical! (if we want to be math rigorous, it should be A(x))

@michalk8 michalk8 merged commit 428316c into main Jun 20, 2023
@michalk8 michalk8 deleted the mott branch June 20, 2023 13:07
@michalk8 michalk8 restored the mott branch June 20, 2023 13:08
@michalk8 michalk8 deleted the mott branch June 20, 2023 13:08
michalk8 added a commit that referenced this pull request Jun 27, 2024
* use lineax to solve linear system in implicit diff

* doc

* fix

* make lineax solvers optional, add a jax default

* pydoc

* pydoc

* pydoc

* pydoc

* pydoc

* pydoc

* selective tests

* fixing another test

* reintroduce ridge for jax solvers, to pass tests

* fix again soft-sort using ridge

* pydoc

* pydoc.

* lint

* increase epsilon to ensure no_precond works.

* readded backprop in test hessian + comments

* F401 in unused import.

* change tolerance for kernel mode

* remove finite diff / backprop test.

* adding lineax in __init__ for docs.

* adding back try import in test.

* docs + test_back

* mod back

* Update readthedocs.yml

* Remove `contextlib`

* Fix wrong file name

---------

Co-authored-by: Michal Klein <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants