Use QR factorization for least squares #1050
Conversation
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | +10.00 +/- 6.37 | +5.11e-02 +/- 3.26e-02 | 5.62e-01 +/- 2.8e-02 | 5.11e-01 +/- 1.7e-02 |
test_build_transform_fft_midres | +4.88 +/- 5.52 | +2.95e-02 +/- 3.33e-02 | 6.33e-01 +/- 1.4e-02 | 6.03e-01 +/- 3.0e-02 |
test_build_transform_fft_highres | +2.33 +/- 6.31 | +2.35e-02 +/- 6.38e-02 | 1.03e+00 +/- 4.4e-02 | 1.01e+00 +/- 4.6e-02 |
test_equilibrium_init_lowres | +2.40 +/- 2.79 | +9.00e-02 +/- 1.04e-01 | 3.84e+00 +/- 6.9e-02 | 3.75e+00 +/- 7.9e-02 |
test_equilibrium_init_medres | +0.29 +/- 4.40 | +1.22e-02 +/- 1.88e-01 | 4.28e+00 +/- 1.3e-01 | 4.27e+00 +/- 1.4e-01 |
test_equilibrium_init_highres | -0.00 +/- 5.15 | -1.92e-04 +/- 2.95e-01 | 5.73e+00 +/- 6.0e-02 | 5.73e+00 +/- 2.9e-01 |
test_objective_compile_dshape_current | -3.18 +/- 2.21 | -1.27e-01 +/- 8.79e-02 | 3.86e+00 +/- 4.5e-02 | 3.98e+00 +/- 7.5e-02 |
test_objective_compile_atf | -3.49 +/- 1.81 | -3.02e-01 +/- 1.57e-01 | 8.36e+00 +/- 9.3e-02 | 8.66e+00 +/- 1.3e-01 |
test_objective_compute_dshape_current | -0.22 +/- 5.90 | -2.81e-06 +/- 7.50e-05 | 1.27e-03 +/- 5.5e-05 | 1.27e-03 +/- 5.1e-05 |
test_objective_compute_atf | -2.24 +/- 9.50 | -9.80e-05 +/- 4.16e-04 | 4.28e-03 +/- 2.3e-04 | 4.38e-03 +/- 3.5e-04 |
test_objective_jac_dshape_current | -6.41 +/- 7.98 | -2.55e-03 +/- 3.17e-03 | 3.72e-02 +/- 1.8e-03 | 3.98e-02 +/- 2.6e-03 |
test_objective_jac_atf | -5.13 +/- 2.62 | -1.02e-01 +/- 5.23e-02 | 1.89e+00 +/- 3.3e-02 | 1.99e+00 +/- 4.1e-02 |
test_perturb_1 | -3.25 +/- 2.04 | -4.53e-01 +/- 2.84e-01 | 1.35e+01 +/- 1.1e-01 | 1.39e+01 +/- 2.6e-01 |
test_perturb_2 | +1.36 +/- 2.34 | +2.51e-01 +/- 4.30e-01 | 1.86e+01 +/- 3.2e-01 | 1.84e+01 +/- 2.9e-01 |
test_proximal_jac_atf | -1.61 +/- 2.11 | -1.20e-01 +/- 1.58e-01 | 7.35e+00 +/- 1.3e-01 | 7.47e+00 +/- 8.8e-02 |
test_proximal_freeb_compute | -0.79 +/- 1.89 | -1.43e-03 +/- 3.41e-03 | 1.79e-01 +/- 1.6e-03 | 1.80e-01 +/- 3.0e-03 |
test_proximal_freeb_jac | -1.09 +/- 2.67 | -8.15e-02 +/- 2.00e-01 | 7.40e+00 +/- 1.3e-01 | 7.48e+00 +/- 1.5e-01 |
test_solve_fixed_iter | +17.48 +/- 9.62 | +2.74e+00 +/- 1.51e+00 | 1.84e+01 +/- 8.7e-01 | 1.57e+01 +/- 1.2e+00 |
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #1050      +/-  ##
==========================================
+ Coverage   95.26%   95.28%    +0.01%
==========================================
  Files          87       87
  Lines       21719    21776       +57
==========================================
+ Hits        20691    20749       +58
+ Misses       1028     1027        -1
```
This reverts commit cb49194.
Should be merged before #1050 so the benchmark will run.
The solve time did not change appreciably for the fixed iter benchmark? That is strange; is QR set to default?
It is, I'm surprised too. All of the benchmarks look really noisy though, so I'll try just rerunning them.
It looks like it might be slower still, though the uncertainty is high so it's hard to tell. Might also be resolution dependent.
@@ -285,6 +287,10 @@ def lsqtr(  # noqa: C901 - FIXME: simplify this

```python
            step_h, hits_boundary, alpha = trust_region_step_exact_cho(
                g_h, B_h, trust_radius, alpha
            )
        elif tr_method == "qr":
```
We can calculate `p_newton` at the top and send it to this subproblem function. We can add another `elif` at line 267. This should save us 1 QR decomposition per inner while loop (this loop, not the one inside the `trust_region_step_exact_qr` function).
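As a sketch of that suggestion (illustrative names and plain NumPy/SciPy, not the actual DESC code), the factorization and Newton step can be computed once per outer iteration and then reused inside the trust-region inner loop:

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

def newton_step_qr(J, f):
    """Factor J once and return the unregularized Newton step (illustrative)."""
    Q, R = qr(J, mode="economic")      # J = Q R with R upper triangular
    return solve_triangular(R, -Q.T @ f)

rng = np.random.default_rng(0)
J = rng.standard_normal((20, 5))
f = rng.standard_normal(20)

# Hoisted: p_newton is computed once here, then passed into the subproblem
# instead of being refactored on every inner iteration.
p_newton = newton_step_qr(J, f)

# It agrees with the ordinary least squares solution of J p = -f.
p_ref, *_ = np.linalg.lstsq(J, -f, rcond=None)
print(np.allclose(p_newton, p_ref))  # True
```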
```python
            p_newton = solve_triangular_regularized(R, -Q.T @ f)
        else:
            Q, R = qr(J.T, mode="economic")
            p_newton = Q @ solve_triangular_regularized(R.T, f, lower=True)
```
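For the underdetermined branch, a hedged sketch of why QR of `J.T` gives the minimum-norm step (here written for `J p = -f`; the sign convention in the actual code may differ):

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

def min_norm_step(J, f):
    """Minimum-norm solution of J p = -f for wide J, via QR of J.T (illustrative)."""
    Q, R = qr(J.T, mode="economic")   # J.T = Q R, so J = R.T Q.T
    # J p = -f  =>  R.T (Q.T p) = -f; solve the lower triangular system, map back.
    y = solve_triangular(R.T, -f, lower=True)
    return Q @ y                       # p lies in the row space of J => min norm

rng = np.random.default_rng(1)
J = rng.standard_normal((4, 9))        # more unknowns than residuals
f = rng.standard_normal(4)
p = min_norm_step(J, f)

print(np.allclose(J @ p, -f))          # True: exact residual
p_ref, *_ = np.linalg.lstsq(J, -f, rcond=None)  # lstsq also returns min-norm
print(np.allclose(p, p_ref))           # True
```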
The calculation of `p_newton` doesn't depend on any variable inside this function, nor do `J` or `f` change inside the while loop which calls this function. I think we should be able to calculate `p_newton` once per full optimization iteration and use it as the check for the trust region algorithm.
```python
        k += 1
        return alpha, alpha_lower, alpha_upper, phi, k

    alpha, *_ = while_loop(
```
A single QR is faster than SVD on both CPU and GPU, but this `while_loop` is very slow. So, if we ever need to call `falsefun`, that iteration is significantly slower than SVD on CPU. For GPU, I guess the execution of this `while_loop` is not that slow, so we saw a speed-up with QR (not a single QR, but the overall solution time decreased). Is there a better way to do this part?
Not really if you want the exact solution to the subproblem. You can get an approximate solution using the dogleg or subspace method but I've found those don't perform as well.
I will look into ways of optimizing how to decompose `Ji`, given that we already know the decomposition `J = QR`. I am not sure, but there should be some linear algebra trick we can use.
There are methods for updating a QR factorization when you add rows. Suppose we have

$$ QR = J $$

what we want is

$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$

for different values of $\alpha$.
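For context, the stacked matrix above is exactly the one whose least squares solution gives the regularized step, which is why an updated factorization would help. A small sketch (plain NumPy/SciPy, not DESC code) verifying the equivalence with the normal equations:

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

def regularized_step(J, f, alpha):
    """Solve min_p ||J p + f||^2 + alpha^2 ||p||^2 via QR of the stacked matrix."""
    m, n = J.shape
    Ja = np.vstack([J, alpha * np.eye(n)])    # augmented matrix (J; alpha*I)
    fa = np.concatenate([f, np.zeros(n)])
    Q, R = qr(Ja, mode="economic")
    return solve_triangular(R, -Q.T @ fa)

rng = np.random.default_rng(2)
J = rng.standard_normal((10, 4))
f = rng.standard_normal(10)
alpha = 0.3
p = regularized_step(J, f, alpha)

# Agrees with the normal-equations solution (J.T J + alpha^2 I) p = -J.T f.
p_ref = np.linalg.solve(J.T @ J + alpha**2 * np.eye(4), -J.T @ f)
print(np.allclose(p, p_ref))  # True
```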
For reference, here is some profiling on CPU (L = M = N = 10). (Profiling output not preserved in this export.)
Here are some timings on CPU and GPU for just a single SVD/QR. (Timing output not preserved in this export.)
Looks like a single QR is ~10x faster on GPU and ~5x faster on CPU. So if we're taking full Newton steps, QR is always better. But if the trust region constraint is binding, it's less clear-cut, and which is faster will depend on the hardware and how many steps are needed in the rootfinding.
Should there be a test comparing SVD and QR?
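A minimal sketch of what such a test could assert (for a well-conditioned problem the two factorizations should give the same Newton step; illustrative, not the actual test suite):

```python
import numpy as np
from scipy.linalg import qr, solve_triangular

rng = np.random.default_rng(3)
J = rng.standard_normal((30, 6))
f = rng.standard_normal(30)

# Newton step via SVD: p = -V diag(1/s) U.T f
U, s, Vt = np.linalg.svd(J, full_matrices=False)
p_svd = Vt.T @ ((U.T @ -f) / s)

# Newton step via QR: solve R p = -Q.T f
Q, R = qr(J, mode="economic")
p_qr = solve_triangular(R, -Q.T @ f)

print(np.allclose(p_svd, p_qr))  # True
```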
```python
    def truefun(*_):
        return p_newton, False, 0.0

    def falsefun(*_):
```
So this is what's used if there is no exact solution within the trust region? Could you give me a general overview of the function? I can't really tell what's going on.
Basically, we first check to see if the full Newton step is within the trust region. If it is, we use that (`truefun`). Otherwise, we do some rootfinding in the regularization parameter `alpha` to find the regularized step that lies on the boundary of the trust region (`falsefun`).
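A plain-Python sketch of that logic (using `scipy.optimize.brentq` for the rootfinding instead of the custom `while_loop`; names, brackets, and sign conventions here are illustrative, not the DESC implementation):

```python
import numpy as np
from scipy.linalg import qr, solve_triangular
from scipy.optimize import brentq

def trust_region_step(J, f, trust_radius):
    """Use the Newton step if it fits in the trust region; otherwise root-find
    alpha so the regularized step lands on the boundary (illustrative)."""
    n = J.shape[1]
    Q, R = qr(J, mode="economic")
    p_newton = solve_triangular(R, -Q.T @ f)
    if np.linalg.norm(p_newton) <= trust_radius:
        return p_newton, False, 0.0            # "truefun" branch

    def step(alpha):
        # regularized step from the stacked system (J; alpha*I)
        Ja = np.vstack([J, alpha * np.eye(n)])
        fa = np.concatenate([f, np.zeros(n)])
        Qa, Ra = qr(Ja, mode="economic")
        return solve_triangular(Ra, -Qa.T @ fa)

    # ||step(alpha)|| decreases monotonically in alpha, so phi has one root.
    def phi(alpha):
        return np.linalg.norm(step(alpha)) - trust_radius

    alpha = brentq(phi, 1e-12, 1e6)            # "falsefun" branch
    return step(alpha), True, alpha

rng = np.random.default_rng(4)
J = rng.standard_normal((15, 5))
f = 10 * rng.standard_normal(15)
p, hits_boundary, alpha = trust_region_step(J, f, trust_radius=1e-3)
print(hits_boundary, np.isclose(np.linalg.norm(p), 1e-3))
```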
```python
        return p, True, alpha

    return cond(jnp.linalg.norm(p_newton) <= trust_radius, truefun, falsefun, None)
```
Do `truefun` and `falsefun` use the `*_` that is defined within this `trust_region_step_exact_qr` function?
No, that's just a dummy variable to eat up the other return values from the loop. The only output we need from the loop is the final value of `alpha`.
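For example, the `alpha, *_ = ...` pattern simply keeps the first element of the loop carry and discards the rest (hypothetical values):

```python
# The loop carries (alpha, alpha_lower, alpha_upper, phi, k), but only the
# final alpha is needed; *_ absorbs the remaining elements.
state = (0.5, 0.0, 1.0, -0.02, 7)   # hypothetical final loop state
alpha, *_ = state
print(alpha)  # 0.5
```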
```python
@functools.partial(jit, static_argnames="lower")
def solve_triangular_regularized(R, b, lower=False):
    """Solve Rx=b for triangular, possibly rank deficient R.
```
Is a docstring for the args necessary? (i.e. I don't know what `lower` is)
It solves `Rx = b` for a lower or upper triangular matrix `R`.
It's not public, so I don't think a full docstring is necessary. It's basically the same as `jax.scipy.linalg.solve_triangular` but with a bit of pre-processing.
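One plausible form of such pre-processing (an assumed sketch, not the actual DESC implementation): clip near-zero diagonal entries of `R` so the back-substitution stays finite even when `R` is rank deficient.

```python
import numpy as np
from scipy.linalg import solve_triangular

def solve_triangular_regularized(R, b, lower=False):
    """Solve R x = b for triangular, possibly rank-deficient R (sketch).

    Near-zero diagonal entries are replaced by a small threshold so the
    triangular solve stays finite; this is an assumed implementation,
    not the actual DESC one.
    """
    d = np.diag(R).copy()
    eps = np.finfo(R.dtype).eps * max(R.shape) * np.max(np.abs(d))
    d_safe = np.where(np.abs(d) < eps, eps, d)
    R_safe = R + np.diag(d_safe - d)   # only the diagonal is modified
    return solve_triangular(R_safe, b, lower=lower)

# Rank-deficient upper triangular example: a plain solve would divide by zero.
R = np.array([[1.0, 2.0], [0.0, 0.0]])
b = np.array([1.0, 0.0])
x = solve_triangular_regularized(R, b)
print(np.all(np.isfinite(x)))  # True
```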
I still think that we can take the `p_newton` calculation out of the inner while loop. I assume there will be a follow-up PR about QR to make `falsefun` faster (at least I plan to work on that), so I can try some additional optimizations there. For now, this version is fine and works better than SVD on GPU.
Resolves #1078

Some performance improvements for the QR decomposition used in optimization, which was first introduced in #1050.

- Take the `p_newton` calculation out of the inner while loop, since it is basically calculating the same QR over and over again.
- ~Use a proper QR update procedure for `falsefun` in `trust_region_step_exact_qr`. That is, we already know the QR decomposition `J = QR`; if we stack a diagonal matrix `aI` onto `J`, then instead of taking the whole QR decomposition again, there is a more clever way of updating the QR. There are methods for updating a QR factorization when you add rows. Suppose we have~

$$ QR = J $$

what we want is

$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$

The QR update procedure can be implemented in a later PR with Householder matrices, but for now it seems a bit inefficient to implement in JAX: since QR is computed by the Fortran LAPACK routines in SciPy and JAX, our custom QR-ish thing would be slow.
Adds the option to use QR decomposition for solving the least squares trust region problem. This appears to be significantly faster than SVD with negligible loss in accuracy (at least on our tests). This also makes QR the new default instead of SVD.

There are a few other places where we use SVD that could possibly also be replaced by QR:

- `perturb.py`, though probably not worth it here, since the initial SVD and initial QR take about the same time, but SVD can reuse the factorization for 2nd and higher order perturbations.
- `ProximalProjection._jvp_f` - I tried replacing this, but it seems here we need the extra accuracy. Could maybe get it to work if we also implement some of the ideas from "Ideas to make `ProximalProjection` more robust" (#802), but that can wait for another PR.

Resolves #708