-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use lineax to solve linear system in implicit diff #370
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@marcocuturi the failing tests can be remedied by adding an importorskip, e.g., inside the test functions that use |
Codecov Report
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## main #370 +/- ##
==========================================
- Coverage 88.43% 87.73% -0.71%
==========================================
Files 51 52 +1
Lines 5605 5650 +45
Branches 836 573 -263
==========================================
Hits 4957 4957
- Misses 529 568 +39
- Partials 119 125 +6
|
I went for something in the middle, sometimes I do use the ridge parameters so that JAX solvers run. |
specify symmetry. This solver is by default one of `lineax`'s `CG` or | ||
`NormalCG` solvers, if the package can be imported, as described in | ||
:func:`~ott.solvers.linear.lineax_implicit.solve_lineax`. | ||
The pure `JAX` alternative is described in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:mod:jax
ridge_kernel: float = 0.0 | ||
ridge_identity: float = 0.0 | ||
solver: Optional[Solver_t] = None | ||
solver_kwargs: Optional[Dict[str, Any]] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to discuss this, am ok with this solutition: alternative solution would be to remove these kwargs and require user to capture any additional keyword arguments using closure/partial.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think avoiding closure/partial is a bit preferable here.
But maybe not clean because IIUC there's no way to mark a Callable that takes optional arguments (...). Another option would be to pass a dictionary (last input = Any) and "fish" variables in there?
kwargs.setdefault("rtol", 1e-6) | ||
kwargs.setdefault("atol", 1e-6) | ||
# Ridge parameters passed to JAX solvers are ignored in lineax. | ||
_ = kwargs.pop("ridge_identity", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need to be popped? They are no default arguments in kwargs
; if it's just because of the tests, would remove the kwargs.pop()
here and adjust the tests instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, this was because of tests, because this requires now to try importing lineax at test time (rather than when running the implicit solver), but will do.
|
||
Args: | ||
lin: Linear operator | ||
b: vector such that sought `x` is such that `lin(x)=b` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too much "such that", would rephrase as:
Vector :math:b
such that :math:A x = b
.
And update the docs of the other function as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added something a bit less mathematical! (if we want to be math rigorous, it should be A(x)
)
* use lineax to solve linear system in implicit diff * doc * fix * make lineax solvers optional, add a jax default * pydoc * pydoc * pydoc * pydoc * pydoc * pydoc * selective tests * fixing another test * reintroduce ridge for jax solvers, to pass tests * fix again soft-sort using ridge * pydoc * pydoc. * lint * increase epsilon to ensure no_precond works. * readded backprop in test hessian + comments * F401 in unused import. * change tolerance for kernel mode * remove finite diff / backprop test. * adding lineax in __init__ for docs. * adding back try import in test. * docs + test_back * mod back * Update readthedocs.yml * Remove `contextlib` * Fix wrong file name --------- Co-authored-by: Michal Klein <[email protected]>
CG
(symmetric case) orNormalCG
(not symmetric case) to handle linear solves, when possible (at this moment this hinges mostly on whether is running python >=3.9). Default back to JAX's nativescipy.sparse.linalg.cg
(or normal CG in unsymmetric case) when lineax cannot be imported.ridge_kernel
andridge_identity
) but are not handled for lineax solvers.lineax
.