Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make factorize_linear_constraints faster #1374

Merged
merged 27 commits into from
Dec 3, 2024
Merged

Make factorize_linear_constraints faster #1374

merged 27 commits into from
Dec 3, 2024

Conversation

YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented Nov 18, 2024

  • Replaces the for loop inside factorize_linear_constrains for degenerate constraints by np.unique along axis=0
  • Adds benchmark test for LinearConstraintProjection build. Considering that initialization of the optimization problem takes around 20 seconds, this is a good case to test every PR.
  • Refactors solve_fixed_iter benchmark to 2 parts, one with everything compiled, the other with cleared cache.
  • Changes null-space and particular solution calculations to use jax.numpy.linalg.svd.

#1300 added the check for degenerate constraints check to the factorize_linear_constraints function. However, the loop over rows and calling np.where which is basically another loop makes it very slow (at least on GPUs). I wouldn't expect it to be that slow but apparently, it takes more than actual null-space and a particular solution calculations. Given that this check is not that necessary for most cases, this is a big burden.

GPU Benchmarks

Here is the old version,
image

The new version with jax.svd and updated degenerate constraint handler,
image

@YigitElma
Copy link
Collaborator Author

YigitElma commented Nov 18, 2024

If anyone is interested here are the benchmarks for the second time factorize_linear_constraints is called (everything compiled)
image
image

And this is the script that I use,

eq = get("W7-X")

objective = ObjectiveFunction(ForceBalance(eq=eq))
constraints = get_fixed_boundary_constraints(eq=eq)
optimizer = Optimizer("lsq-exact")

eq, solver_outputs = eq.solve(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    maxiter=0,
    verbose=3,
    ftol=0,
    gtol=0,
    xtol=0,
)

@YigitElma YigitElma requested review from a team, rahulgaur104, f0uriest, ddudt, dpanici, kianorr, sinaatalay and unalmis and removed request for a team November 18, 2024 06:48
@YigitElma YigitElma self-assigned this Nov 18, 2024
@YigitElma YigitElma added performance New feature or request to make the code faster skip_changelog No need to update changelog on this PR labels Nov 18, 2024
Copy link

codecov bot commented Nov 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.57%. Comparing base (c8a4077) to head (b23b5e0).
Report is 7 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1374      +/-   ##
==========================================
- Coverage   95.58%   95.57%   -0.01%     
==========================================
  Files          96       96              
  Lines       24477    24477              
==========================================
- Hits        23396    23395       -1     
- Misses       1081     1082       +1     
Files with missing lines Coverage Δ
desc/objectives/utils.py 100.00% <100.00%> (ø)
desc/utils.py 90.79% <100.00%> (ø)

... and 2 files with indirect coverage changes

Copy link
Contributor

github-actions bot commented Nov 18, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     -0.48 +/- 3.39     | -2.89e-03 +/- 2.03e-02 |  5.95e-01 +/- 1.6e-02  |  5.98e-01 +/- 1.2e-02  |
 test_build_transform_fft_highres        |     -1.13 +/- 1.40     | -1.09e-02 +/- 1.34e-02 |  9.47e-01 +/- 9.3e-03  |  9.58e-01 +/- 9.6e-03  |
 test_equilibrium_init_lowres            |     -0.21 +/- 1.37     | -8.08e-03 +/- 5.18e-02 |  3.77e+00 +/- 3.5e-02  |  3.78e+00 +/- 3.8e-02  |
 test_objective_compile_atf              |     -0.90 +/- 3.74     | -7.11e-02 +/- 2.95e-01 |  7.80e+00 +/- 2.1e-01  |  7.87e+00 +/- 2.1e-01  |
 test_objective_compute_atf              |     +1.74 +/- 2.66     | +1.81e-04 +/- 2.76e-04 |  1.06e-02 +/- 1.6e-04  |  1.04e-02 +/- 2.3e-04  |
 test_objective_jac_atf                  |     +1.69 +/- 3.36     | +3.18e-02 +/- 6.33e-02 |  1.91e+00 +/- 5.5e-02  |  1.88e+00 +/- 3.1e-02  |
 test_perturb_1                          |     +0.58 +/- 2.43     | +8.03e-02 +/- 3.36e-01 |  1.39e+01 +/- 3.1e-01  |  1.38e+01 +/- 1.4e-01  |
 test_proximal_jac_atf                   |     +0.08 +/- 1.18     | +6.79e-03 +/- 9.56e-02 |  8.09e+00 +/- 7.0e-02  |  8.08e+00 +/- 6.5e-02  |
 test_proximal_freeb_compute             |     -0.90 +/- 0.96     | -1.78e-03 +/- 1.88e-03 |  1.95e-01 +/- 1.4e-03  |  1.97e-01 +/- 1.2e-03  |
 test_solve_fixed_iter_compiled          |     -0.18 +/- 1.40     | -3.07e-02 +/- 2.34e-01 |  1.67e+01 +/- 1.0e-01  |  1.67e+01 +/- 2.1e-01  |
 test_build_transform_fft_lowres         |     -3.96 +/- 7.76     | -2.13e-02 +/- 4.18e-02 |  5.17e-01 +/- 3.3e-02  |  5.39e-01 +/- 2.6e-02  |
 test_equilibrium_init_medres            |     +6.72 +/- 7.21     | +2.73e-01 +/- 2.93e-01 |  4.34e+00 +/- 1.9e-01  |  4.07e+00 +/- 2.3e-01  |
 test_equilibrium_init_highres           |     +0.21 +/- 2.38     | +1.13e-02 +/- 1.27e-01 |  5.34e+00 +/- 8.4e-02  |  5.33e+00 +/- 9.5e-02  |
 test_objective_compile_dshape_current   |     -1.19 +/- 4.81     | -4.54e-02 +/- 1.84e-01 |  3.78e+00 +/- 1.8e-01  |  3.83e+00 +/- 4.1e-02  |
 test_objective_compute_dshape_current   |     -0.11 +/- 1.61     | -4.18e-06 +/- 5.87e-05 |  3.64e-03 +/- 3.6e-05  |  3.64e-03 +/- 4.6e-05  |
 test_objective_jac_dshape_current       |     +0.08 +/- 9.74     | +3.26e-05 +/- 3.75e-03 |  3.86e-02 +/- 3.6e-03  |  3.85e-02 +/- 1.0e-03  |
 test_perturb_2                          |     +0.19 +/- 3.61     | +3.56e-02 +/- 6.84e-01 |  1.90e+01 +/- 4.7e-01  |  1.89e+01 +/- 5.0e-01  |
 test_proximal_freeb_jac                 |     -1.35 +/- 2.72     | -1.02e-01 +/- 2.06e-01 |  7.48e+00 +/- 1.5e-01  |  7.58e+00 +/- 1.4e-01  |
 test_solve_fixed_iter                   |     -3.28 +/- 2.19     | -9.35e-01 +/- 6.25e-01 |  2.76e+01 +/- 4.2e-01  |  2.85e+01 +/- 4.7e-01  |
+test_LinearConstraintProjection_build   |    -23.15 +/- 1.44     | -6.79e+00 +/- 4.24e-01 |  2.25e+01 +/- 3.7e-01  |  2.93e+01 +/- 2.0e-01  |

dpanici
dpanici previously approved these changes Nov 18, 2024
@YigitElma YigitElma marked this pull request as draft November 18, 2024 17:09
desc/utils.py Outdated Show resolved Hide resolved
desc/objectives/utils.py Outdated Show resolved Hide resolved
desc/utils.py Outdated Show resolved Hide resolved
@YigitElma YigitElma requested a review from dpanici November 19, 2024 20:28
desc/utils.py Outdated Show resolved Hide resolved
desc/utils.py Outdated Show resolved Hide resolved
@dpanici
Copy link
Collaborator

dpanici commented Nov 20, 2024

Try throwing execute on cpu decorator

@YigitElma
Copy link
Collaborator Author

YigitElma commented Nov 20, 2024

New results with 32 cores. Overall effect is the same. When Rory said 2-3 seconds, those were for low res appearing in solve_continuation, My script is different and uses full res.

These are SVD and for loop for degeneracy, on CPU and GPU respectively, (and I show the second time it is called so everything is compiled)
image
image
GPU version is slower for the first run, but it seems to be the same on later.

Since QR is not done yet, I won't share the results of it but it is impressive on GPU :) So here are the results for SVD+np.unique on same settings
image
image

Since we use np.linalg.svd (!?) there is no performance gain on SVD on GPU. They perform similar for so need to add execute_on_cpu flag.

Copy link
Collaborator

@dpanici dpanici left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does using jnp make it faster on gpu now?

@YigitElma
Copy link
Collaborator Author

does using jnp make it faster on gpu now?

Yes, I have updated the benchmarks in the description. The CPU is the same, but it is much faster on the GPU. Especially once the SVD is compiled (yes some internal jax functions get compiled even outside of jit), the speed up is extreme for high res.

@YigitElma YigitElma merged commit 02aced6 into master Dec 3, 2024
25 checks passed
@YigitElma YigitElma deleted the yge/degenerate branch December 3, 2024 22:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance New feature or request to make the code faster skip_changelog No need to update changelog on this PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants