-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make factorize_linear_constraints
faster
#1374
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #1374 +/- ##
==========================================
- Coverage 95.58% 95.57% -0.01%
==========================================
Files 96 96
Lines 24477 24477
==========================================
- Hits 23396 23395 -1
- Misses 1081 1082 +1
|
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | -0.48 +/- 3.39 | -2.89e-03 +/- 2.03e-02 | 5.95e-01 +/- 1.6e-02 | 5.98e-01 +/- 1.2e-02 |
test_build_transform_fft_highres | -1.13 +/- 1.40 | -1.09e-02 +/- 1.34e-02 | 9.47e-01 +/- 9.3e-03 | 9.58e-01 +/- 9.6e-03 |
test_equilibrium_init_lowres | -0.21 +/- 1.37 | -8.08e-03 +/- 5.18e-02 | 3.77e+00 +/- 3.5e-02 | 3.78e+00 +/- 3.8e-02 |
test_objective_compile_atf | -0.90 +/- 3.74 | -7.11e-02 +/- 2.95e-01 | 7.80e+00 +/- 2.1e-01 | 7.87e+00 +/- 2.1e-01 |
test_objective_compute_atf | +1.74 +/- 2.66 | +1.81e-04 +/- 2.76e-04 | 1.06e-02 +/- 1.6e-04 | 1.04e-02 +/- 2.3e-04 |
test_objective_jac_atf | +1.69 +/- 3.36 | +3.18e-02 +/- 6.33e-02 | 1.91e+00 +/- 5.5e-02 | 1.88e+00 +/- 3.1e-02 |
test_perturb_1 | +0.58 +/- 2.43 | +8.03e-02 +/- 3.36e-01 | 1.39e+01 +/- 3.1e-01 | 1.38e+01 +/- 1.4e-01 |
test_proximal_jac_atf | +0.08 +/- 1.18 | +6.79e-03 +/- 9.56e-02 | 8.09e+00 +/- 7.0e-02 | 8.08e+00 +/- 6.5e-02 |
test_proximal_freeb_compute | -0.90 +/- 0.96 | -1.78e-03 +/- 1.88e-03 | 1.95e-01 +/- 1.4e-03 | 1.97e-01 +/- 1.2e-03 |
test_solve_fixed_iter_compiled | -0.18 +/- 1.40 | -3.07e-02 +/- 2.34e-01 | 1.67e+01 +/- 1.0e-01 | 1.67e+01 +/- 2.1e-01 |
test_build_transform_fft_lowres | -3.96 +/- 7.76 | -2.13e-02 +/- 4.18e-02 | 5.17e-01 +/- 3.3e-02 | 5.39e-01 +/- 2.6e-02 |
test_equilibrium_init_medres | +6.72 +/- 7.21 | +2.73e-01 +/- 2.93e-01 | 4.34e+00 +/- 1.9e-01 | 4.07e+00 +/- 2.3e-01 |
test_equilibrium_init_highres | +0.21 +/- 2.38 | +1.13e-02 +/- 1.27e-01 | 5.34e+00 +/- 8.4e-02 | 5.33e+00 +/- 9.5e-02 |
test_objective_compile_dshape_current | -1.19 +/- 4.81 | -4.54e-02 +/- 1.84e-01 | 3.78e+00 +/- 1.8e-01 | 3.83e+00 +/- 4.1e-02 |
test_objective_compute_dshape_current | -0.11 +/- 1.61 | -4.18e-06 +/- 5.87e-05 | 3.64e-03 +/- 3.6e-05 | 3.64e-03 +/- 4.6e-05 |
test_objective_jac_dshape_current | +0.08 +/- 9.74 | +3.26e-05 +/- 3.75e-03 | 3.86e-02 +/- 3.6e-03 | 3.85e-02 +/- 1.0e-03 |
test_perturb_2 | +0.19 +/- 3.61 | +3.56e-02 +/- 6.84e-01 | 1.90e+01 +/- 4.7e-01 | 1.89e+01 +/- 5.0e-01 |
test_proximal_freeb_jac | -1.35 +/- 2.72 | -1.02e-01 +/- 2.06e-01 | 7.48e+00 +/- 1.5e-01 | 7.58e+00 +/- 1.4e-01 |
test_solve_fixed_iter | -3.28 +/- 2.19 | -9.35e-01 +/- 6.25e-01 | 2.76e+01 +/- 4.2e-01 | 2.85e+01 +/- 4.7e-01 |
+test_LinearConstraintProjection_build | -23.15 +/- 1.44 | -6.79e+00 +/- 4.24e-01 | 2.25e+01 +/- 3.7e-01 | 2.93e+01 +/- 2.0e-01 | |
Try throwing execute on cpu decorator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does using jnp make it faster on gpu now?
Yes, I have updated the benchmarks in the description. The CPU is the same, but it is much faster on the GPU. Especially once the SVD is compiled (yes some internal jax functions get compiled even outside of jit), the speed up is extreme for high res. |
factorize_linear_constrains
for degenerate constraints bynp.unique
alongaxis=0
LinearConstraintProjection
build. Considering that initialization of the optimization problem takes around 20 seconds, this is a good case to test every PR.solve_fixed_iter
benchmark to 2 parts, one with everything compiled, the other with cleared cache.jax.numpy.linalg.svd
.#1300 added the check for degenerate constraints check to the
factorize_linear_constraints
function. However, the loop over rows and callingnp.where
which is basically another loop makes it very slow (at least on GPUs). I wouldn't expect it to be that slow but apparently, it takes more than actual null-space and a particular solution calculations. Given that this check is not that necessary for most cases, this is a big burden.GPU Benchmarks
Here is the old version,
The new version with
jax.svd
and updated degenerate constraint handler,