Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use QR factorization for least squares #1050

Merged
merged 25 commits into from
Jul 18, 2024
Merged

Use QR factorization for least squares #1050

merged 25 commits into from
Jul 18, 2024

Conversation

f0uriest
Copy link
Member

@f0uriest f0uriest commented Jun 12, 2024

Adds the option to use QR decomposition for solving least squares trust region problem. This appears to be significantly faster than SVD with negligible loss in accuracy (at least on our tests). This also makes QR the new default instead of SVD.

There are a few other places that we use SVD that could possibly also be replaced by QR:

  • In perturb.py, though probably not worth it here, since initial SVD and initial QR take about the same time, but SVD can reuse the factorization for 2nd and higher order perturbations.
  • In ProximalProjection._jvp_f - I tried replacing this but it seems here we need the extra accuracy. Could maybe get it to work if we also implement some of the ideas from Ideas to make ProximalProjection more robust #802 but that can wait for another PR.

Resolves #708

@f0uriest f0uriest marked this pull request as draft June 12, 2024 20:46
desc/optimize/aug_lagrangian_ls.py Outdated Show resolved Hide resolved
desc/optimize/aug_lagrangian_ls.py Outdated Show resolved Hide resolved
desc/optimize/aug_lagrangian_ls.py Show resolved Hide resolved
desc/optimize/least_squares.py Outdated Show resolved Hide resolved
desc/optimize/least_squares.py Outdated Show resolved Hide resolved
desc/optimize/least_squares.py Show resolved Hide resolved
desc/optimize/tr_subproblems.py Show resolved Hide resolved
Copy link
Contributor

github-actions bot commented Jun 12, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |    +10.00 +/- 6.37     | +5.11e-02 +/- 3.26e-02 |  5.62e-01 +/- 2.8e-02  |  5.11e-01 +/- 1.7e-02  |
 test_build_transform_fft_midres         |     +4.88 +/- 5.52     | +2.95e-02 +/- 3.33e-02 |  6.33e-01 +/- 1.4e-02  |  6.03e-01 +/- 3.0e-02  |
 test_build_transform_fft_highres        |     +2.33 +/- 6.31     | +2.35e-02 +/- 6.38e-02 |  1.03e+00 +/- 4.4e-02  |  1.01e+00 +/- 4.6e-02  |
 test_equilibrium_init_lowres            |     +2.40 +/- 2.79     | +9.00e-02 +/- 1.04e-01 |  3.84e+00 +/- 6.9e-02  |  3.75e+00 +/- 7.9e-02  |
 test_equilibrium_init_medres            |     +0.29 +/- 4.40     | +1.22e-02 +/- 1.88e-01 |  4.28e+00 +/- 1.3e-01  |  4.27e+00 +/- 1.4e-01  |
 test_equilibrium_init_highres           |     -0.00 +/- 5.15     | -1.92e-04 +/- 2.95e-01 |  5.73e+00 +/- 6.0e-02  |  5.73e+00 +/- 2.9e-01  |
 test_objective_compile_dshape_current   |     -3.18 +/- 2.21     | -1.27e-01 +/- 8.79e-02 |  3.86e+00 +/- 4.5e-02  |  3.98e+00 +/- 7.5e-02  |
 test_objective_compile_atf              |     -3.49 +/- 1.81     | -3.02e-01 +/- 1.57e-01 |  8.36e+00 +/- 9.3e-02  |  8.66e+00 +/- 1.3e-01  |
 test_objective_compute_dshape_current   |     -0.22 +/- 5.90     | -2.81e-06 +/- 7.50e-05 |  1.27e-03 +/- 5.5e-05  |  1.27e-03 +/- 5.1e-05  |
 test_objective_compute_atf              |     -2.24 +/- 9.50     | -9.80e-05 +/- 4.16e-04 |  4.28e-03 +/- 2.3e-04  |  4.38e-03 +/- 3.5e-04  |
 test_objective_jac_dshape_current       |     -6.41 +/- 7.98     | -2.55e-03 +/- 3.17e-03 |  3.72e-02 +/- 1.8e-03  |  3.98e-02 +/- 2.6e-03  |
 test_objective_jac_atf                  |     -5.13 +/- 2.62     | -1.02e-01 +/- 5.23e-02 |  1.89e+00 +/- 3.3e-02  |  1.99e+00 +/- 4.1e-02  |
 test_perturb_1                          |     -3.25 +/- 2.04     | -4.53e-01 +/- 2.84e-01 |  1.35e+01 +/- 1.1e-01  |  1.39e+01 +/- 2.6e-01  |
 test_perturb_2                          |     +1.36 +/- 2.34     | +2.51e-01 +/- 4.30e-01 |  1.86e+01 +/- 3.2e-01  |  1.84e+01 +/- 2.9e-01  |
 test_proximal_jac_atf                   |     -1.61 +/- 2.11     | -1.20e-01 +/- 1.58e-01 |  7.35e+00 +/- 1.3e-01  |  7.47e+00 +/- 8.8e-02  |
 test_proximal_freeb_compute             |     -0.79 +/- 1.89     | -1.43e-03 +/- 3.41e-03 |  1.79e-01 +/- 1.6e-03  |  1.80e-01 +/- 3.0e-03  |
 test_proximal_freeb_jac                 |     -1.09 +/- 2.67     | -8.15e-02 +/- 2.00e-01 |  7.40e+00 +/- 1.3e-01  |  7.48e+00 +/- 1.5e-01  |
 test_solve_fixed_iter                   |    +17.48 +/- 9.62     | +2.74e+00 +/- 1.51e+00 |  1.84e+01 +/- 8.7e-01  |  1.57e+01 +/- 1.2e+00  |

Copy link

codecov bot commented Jun 13, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.28%. Comparing base (6bc9955) to head (2a73d92).
Report is 1898 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1050      +/-   ##
==========================================
+ Coverage   95.26%   95.28%   +0.01%     
==========================================
  Files          87       87              
  Lines       21719    21776      +57     
==========================================
+ Hits        20691    20749      +58     
+ Misses       1028     1027       -1     
Files with missing lines Coverage Δ
desc/optimize/aug_lagrangian_ls.py 96.51% <100.00%> (-0.96%) ⬇️
desc/optimize/least_squares.py 99.30% <100.00%> (+<0.01%) ⬆️
desc/optimize/tr_subproblems.py 99.46% <100.00%> (+1.54%) ⬆️
desc/optimize/utils.py 94.95% <100.00%> (+0.19%) ⬆️

... and 1 file with indirect coverage changes

f0uriest added a commit that referenced this pull request Jun 14, 2024
Should be merged before #1050 so benchmark will run.
@dpanici
Copy link
Collaborator

dpanici commented Jun 14, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |    +11.98 +/- 10.94    | +7.05e-02 +/- 6.44e-02 |  6.59e-01 +/- 5.5e-02  |  5.89e-01 +/- 3.4e-02  |
 test_build_transform_fft_midres         |     +3.72 +/- 11.45    | +2.64e-02 +/- 8.12e-02 |  7.36e-01 +/- 4.1e-02  |  7.09e-01 +/- 7.0e-02  |
 test_build_transform_fft_highres        |     -0.22 +/- 7.35     | -2.49e-03 +/- 8.53e-02 |  1.16e+00 +/- 4.7e-02  |  1.16e+00 +/- 7.1e-02  |
 test_equilibrium_init_lowres            |     +4.28 +/- 10.44    | +1.91e-01 +/- 4.67e-01 |  4.67e+00 +/- 2.5e-01  |  4.48e+00 +/- 4.0e-01  |
 test_equilibrium_init_medres            |     +0.18 +/- 9.85     | +9.43e-03 +/- 5.10e-01 |  5.19e+00 +/- 3.2e-01  |  5.18e+00 +/- 4.0e-01  |
 test_equilibrium_init_highres           |     +7.99 +/- 7.43     | +5.26e-01 +/- 4.89e-01 |  7.11e+00 +/- 2.5e-01  |  6.58e+00 +/- 4.2e-01  |
 test_objective_compile_dshape_current   |     +3.87 +/- 8.57     | +1.62e-01 +/- 3.58e-01 |  4.34e+00 +/- 3.0e-01  |  4.18e+00 +/- 2.0e-01  |
+test_objective_compile_atf              |    -24.02 +/- 5.19     | -3.04e+00 +/- 6.56e-01 |  9.60e+00 +/- 4.8e-01  |  1.26e+01 +/- 4.5e-01  |
 test_objective_compute_dshape_current   |     +5.91 +/- 5.87     | +7.11e-05 +/- 7.07e-05 |  1.27e-03 +/- 5.9e-05  |  1.20e-03 +/- 3.8e-05  |
 test_objective_compute_atf              |    -19.20 +/- 20.40    | -1.13e-03 +/- 1.20e-03 |  4.74e-03 +/- 4.8e-04  |  5.86e-03 +/- 1.1e-03  |
 test_objective_jac_dshape_current       |     -3.43 +/- 17.46    | -1.55e-03 +/- 7.89e-03 |  4.36e-02 +/- 4.5e-03  |  4.52e-02 +/- 6.5e-03  |
 test_objective_jac_atf                  |     +0.58 +/- 21.34    | +1.61e-02 +/- 5.92e-01 |  2.79e+00 +/- 4.8e-01  |  2.77e+00 +/- 3.4e-01  |
 test_perturb_1                          |     -1.80 +/- 10.38    | -2.88e-01 +/- 1.66e+00 |  1.57e+01 +/- 5.5e-01  |  1.60e+01 +/- 1.6e+00  |
 test_perturb_2                          |    +15.78 +/- 7.00     | +3.19e+00 +/- 1.42e+00 |  2.34e+01 +/- 1.3e+00  |  2.02e+01 +/- 5.4e-01  |
 test_proximal_jac_atf                   |     +9.60 +/- 7.74     | +7.60e-01 +/- 6.13e-01 |  8.68e+00 +/- 5.5e-01  |  7.92e+00 +/- 2.7e-01  |
 test_proximal_freeb_compute             |     +3.35 +/- 3.70     | +6.03e-03 +/- 6.67e-03 |  1.86e-01 +/- 6.5e-03  |  1.80e-01 +/- 1.7e-03  |
 test_proximal_freeb_jac                 |     +0.25 +/- 3.53     | +1.95e-02 +/- 2.73e-01 |  7.76e+00 +/- 2.1e-01  |  7.74e+00 +/- 1.8e-01  |
 test_solve_fixed_iter                   |    +16.43 +/- 21.12    | +2.88e+00 +/- 3.70e+00 |  2.04e+01 +/- 2.4e+00  |  1.75e+01 +/- 2.8e+00  |

the solve time did not change appreciably for the fixed iter? that is strange, is QR set to default?

@f0uriest
Copy link
Member Author

It is, I'm surprised too.

All of the benchmarks look really noisy though so I'll try just rerunning them

@f0uriest
Copy link
Member Author

It looks like it might be slower still, though the uncertainty is high so hard to tell.
@YigitElma can you try profiling svd vs qr on cpu to see if there's a big difference?

Might also be resolution dependent.

@@ -285,6 +287,10 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
step_h, hits_boundary, alpha = trust_region_step_exact_cho(
g_h, B_h, trust_radius, alpha
)
elif tr_method == "qr":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can calculate p_newton at the top and send it to this subproblem function. We can add another elif at line 267. This should save us 1 QR decomposition per inner while loop (this loop, not the one inside trust_region_step_exact_qr function)

p_newton = solve_triangular_regularized(R, -Q.T @ f)
else:
Q, R = qr(J.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, f, lower=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation of p_newton doesn't depend on any variable inside this function nor J or f changes inside the while which calls this function. I think we should be able to calculate p_newton once per full optimization iteration to use it as check for trust region algorithm.

k += 1
return alpha, alpha_lower, alpha_upper, phi, k

alpha, *_ = while_loop(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A single QR is faster than SVD on both CPU and GPU, but this while_loop is very slow. So, if we ever need to call falsefun that iteration is significantly slower than SVD on CPU. For GPU, I guess the execution of this while_loop is not that slow so we saw a speed-up with QR (not single QR but overall solution time decreased). Is there a better way to do this part?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really if you want the exact solution to the subproblem. You can get an approximate solution using the dogleg or subspace method but I've found those don't perform as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will look into ways of optimizing how to decompose Ji given that you know decomposition J=QR. I am not sure but there should be some linear algebra trick we can use

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are methods for updating a QR factorization when you add rows. Suppose we have

$$ QR = J $$

what we want is

$$ \tilde{Q} \tilde{R} = \begin{pmatrix} J \\ \alpha I \end{pmatrix} $$

for different values of $\alpha$. scipy has functions to do this, though jax doesn't so we'd need to implement ourselves. The ones I'm familiar with are sequential (ie, adding 1 row at a time) so probably aren't great on GPU but there are probably smarter ways. You can look into damped least squares, updating qr factorizations etc.

@YigitElma
Copy link
Collaborator

YigitElma commented Jun 18, 2024

For reference, here are some profiling on CPU, (L=M=N=10)

Using method: lsq-exact
   Iteration     Total nfev        Cost      Cost reduction    Step norm     Optimality   
       0              1          5.894e-03                                    8.717e-01   
init QR:  3.289091110229492
QR: 39.345173358917236 s,  # loops: 10
QR: 14.585150241851807 s,  # loops: 3
QR: 38.896700859069824 s,  # loops: 10
       1              4          2.535e-03      3.359e-03      1.513e-01      2.307e+00   
init QR:  3.139089345932007
QR: 0.00012564659118652344 s,  # loops: 0
       2              5          1.514e-04      2.384e-03      4.853e-02      3.162e-01   
init QR:  3.1483688354492188
QR: 0.00010204315185546875 s,  # loops: 0
       3              6          1.398e-05      1.374e-04      3.380e-02      1.564e-01   
init QR:  3.1345150470733643
QR: 0.0001232624053955078 s,  # loops: 0
       4              7          9.906e-07      1.299e-05      2.334e-02      4.424e-02   
init QR:  3.137814998626709
QR: 9.012222290039062e-05 s,  # loops: 0
       5              8          4.917e-08      9.414e-07      1.382e-02      7.766e-03   
init QR:  3.1379806995391846
QR: 0.00010609626770019531 s,  # loops: 0
       6              9          8.415e-09      4.075e-08      1.352e-02      4.155e-03   
init QR:  3.1388354301452637
QR: 0.00010442733764648438 s,  # loops: 0
       7             10          5.846e-10      7.830e-09      7.161e-03      1.010e-03   
init QR:  3.1450040340423584
QR: 0.00011539459228515625 s,  # loops: 0
       8             11          5.335e-10      5.109e-11      7.642e-03      7.996e-04   
init QR:  3.1360228061676025
QR: 9.322166442871094e-05 s,  # loops: 0
       9             12          2.397e-12      5.311e-10      7.502e-04      3.237e-05   
init QR:  3.1446845531463623
QR: 17.652892112731934 s,  # loops: 4
QR: 21.25301957130432 s,  # loops: 5
      10             14          2.536e-13      2.143e-12      3.222e-04      1.046e-05  
Using method: lsq-exact
   Iteration     Total nfev        Cost      Cost reduction    Step norm     Optimality
       0              1          5.894e-03                                    8.717e-01
init SVD:  5.595107555389404
SVD:  0.3101840019226074
SVD:  0.2919440269470215
       1              3          1.261e-03      4.633e-03      1.270e-01      1.662e+00
init SVD:  5.570924282073975
SVD:  0.0752408504486084
       2              4          6.245e-05      1.199e-03      3.385e-02      2.529e-01
init SVD:  5.5907557010650635
SVD:  0.07500481605529785
       3              5          6.382e-06      5.607e-05      2.447e-02      1.033e-01
init SVD:  5.571263790130615
SVD:  0.07488608360290527
       4              6          3.628e-07      6.020e-06      1.802e-02      2.413e-02
init SVD:  5.565975189208984
SVD:  0.07528972625732422
       5              7          1.849e-08      3.443e-07      1.107e-02      3.971e-03
init SVD:  5.581035137176514
SVD:  0.0754847526550293
       6              8          1.640e-09      1.685e-08      8.119e-03      1.303e-03
init SVD:  5.572343826293945
SVD:  0.07513928413391113
       7              9          6.162e-10      1.024e-09      7.606e-03      8.922e-04
init SVD:  5.565074920654297
SVD:  0.07505941390991211
       8             10          1.497e-11      6.012e-10      2.666e-03      1.343e-04
init SVD:  5.571868419647217
SVD:  0.07536005973815918
SVD:  0.07651829719543457
       9             12          4.644e-12      1.032e-11      2.045e-03      8.266e-05
init SVD:  5.567831993103027
SVD:  0.07741618156433105
      10             13          4.100e-12      5.446e-13      1.811e-03      8.348e-05

init SVD/QR are the first decomposition before trust_region_step_exact_ function (I moved it for QR) and the rest are the time it takes each trust_region_step_exact_ call.

@f0uriest
Copy link
Member Author

Here are some timings on cpu and gpu for just a single SVD/QR:

GPU:

M=N=6
svd
140 ms ± 222 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
11.7 ms ± 2.92 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
cholesky
2.14 ms ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

M=N=8
svd
519 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
40.1 ms ± 8.16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cholesky
5.22 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

M=N=10
svd
1.87 s ± 797 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
130 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
cholesky
14.8 ms ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

CPU:

M=N=6
svd
383 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
88.5 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
cholesky
7.22 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

M=N=8
svd
4.26 s ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
700 ms ± 4.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
cholesky
30 ms ± 540 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

M=N=10
svd
24.5 s ± 149 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
qr
4.19 s ± 9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
cholesky
139 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

looks like a single QR is ~10x faster on GPU, ~5x faster on CPU.

So if we're taking full newton steps, QR is always better. But if the trust region constraint is binding its less clear cut, and which is faster will depend on the hardware and how many steps are needed in the rootfinding

@f0uriest f0uriest marked this pull request as ready for review June 21, 2024 18:25
@f0uriest f0uriest requested a review from rahulgaur104 June 21, 2024 18:27
@rahulgaur104
Copy link
Collaborator

Should there be a test comparing SVD and QR?

def truefun(*_):
return p_newton, False, 0.0

def falsefun(*_):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is what's used if there is no exact solution within the trust region? could you give me a general overview of the function? I can't really tell what's going on

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically we first check to see if the full newton step is within the trust region. If it is, we use that (truefun).

Otherwise, we do some rootfinding in the regularization parameter alpha to find the regularized step that lies on the boundary of the trust region (falsefun).


return p, True, alpha

return cond(jnp.linalg.norm(p_newton) <= trust_radius, truefun, falsefun, None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do truefun and falsefun use the *_ that is defined in within this trust_region_step_exact_qr function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, that's just a dummy variable to eat up the other return values from the loop. The only output we need from the loop is the final value of alpha.


@functools.partial(jit, static_argnames="lower")
def solve_triangular_regularized(R, b, lower=False):
"""Solve Rx=b for triangular, possibly rank deficient R.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring for args necessary? (i.e. I don't know what lower is)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It solves Rx=b for lower or upper triangular R matrix

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not public, so I don't think a full doc is necessary. It's basically the same as jax.scipy.linalg.solve_triangular but with a bit of pre-processing.

Copy link
Collaborator

@YigitElma YigitElma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think that we can take p_newton calculation out of the inner while loop. I assume there will be a follow-up PR about QR to make falsefun faster (at least I plan to work on that), so I can try some additional optimizations there. For now, this version is fine and works better than SVD on GPU.

@f0uriest f0uriest merged commit 1df51c4 into master Jul 18, 2024
18 checks passed
@f0uriest f0uriest deleted the rc/qr branch July 18, 2024 21:45
ddudt added a commit that referenced this pull request Aug 21, 2024
Resolves #1078 

Some performance improvements for QR decomposition used in optimization
which was first introduced in #1050.

- Take the `p_newton` calculation out of inner while loop, since it is
basically calculating the same QR over and over again
- ~Use proper QR update procedure for the `falsefun` in
`trust_region_step_exact_qr`. That is we already now QR decomposition of
`J=QR`, if we stack a diagonal matrix `aI` to `J` then instead of taking
the whole QR decomposition again, there is a more clever way of updating
the QR.There are methods for updating a QR factorization when you add
rows. Suppose we have~

$$
QR = J
$$

what we want is 

$$
\tilde{Q} \tilde{R} = \begin{pmatrix} J \\ 
\alpha I \end{pmatrix}
$$

The QR update procedure can be implemented on a later PR with
Householder matrices, but for now, it seems a bit inefficient to
implement using JAX since QR is calculated by Fortran package LAPACK on
Scipy and Jax, our custom QR'ish thing will be slow.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Try using QR decomp for least squares trust region subproblem
7 participants