From 56f5585bdba3e1fde5c13992691d6b630e738de9 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 01:12:04 -0500 Subject: [PATCH 01/18] make degenerate constraint check faster --- desc/objectives/utils.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 161c3f057e..e8b9d8cb4a 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -98,27 +98,15 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # if the entries of b aren't the same then the constraints are actually # incompatible and so we will leave those to be caught later. A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))]) - row_idx_to_delete = np.array([], dtype=int) - for row_idx in range(A_augmented.shape[0]): - # find all rows equal to this row - rows_equal_to_this_row = np.where( - np.all(A_augmented[row_idx, :] == A_augmented, axis=1) - )[0] - # find the rows equal to this row that are not this row - rows_equal_to_this_row_but_not_this_row = rows_equal_to_this_row[ - rows_equal_to_this_row != row_idx - ] - # if there are rows equal to this row that aren't this row, AND this particular - # row has not already been detected as a duplicate of an earlier one and slated - # for deletion, add the duplicate row indices to the array of - # rows to be deleted - if ( - rows_equal_to_this_row_but_not_this_row.size - and row_idx not in row_idx_to_delete - ): - row_idx_to_delete = np.append(row_idx_to_delete, rows_equal_to_this_row[1:]) - # delete the affected rows, and also the corresponding rows of b - A_augmented = np.delete(A_augmented, row_idx_to_delete, axis=0) + + # Find unique rows of A_augmented + unique_rows, unique_indices = np.unique(A_augmented, axis=0, return_index=True) + + # Sort the indices to preserve the order of appearance + unique_indices = np.sort(unique_indices) + + # Extract the unique rows + A_augmented = A_augmented[unique_indices] A = A_augmented[:, :-1] b = np.atleast_1d(A_augmented[:, -1].squeeze()) From 8486bfcdd11e2362641d8e6f3f8e70497a22e6f5 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 01:45:11 -0500 Subject: [PATCH 02/18] use jnp instead of np --- desc/objectives/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index e8b9d8cb4a..3506830eed 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -97,14 +97,16 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # which are duplicate rows of A that also have duplicate entries of b, # if the entries of b aren't the same then the constraints are actually # incompatible and so we will leave those to be caught later. - A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))]) + A_augmented = jnp.hstack([A, jnp.reshape(b, (A.shape[0], 1))]) # Find unique rows of A_augmented - unique_rows, unique_indices = np.unique(A_augmented, axis=0, return_index=True) + unique_rows, unique_indices = jnp.unique(A_augmented, axis=0, return_index=True) # Sort the indices to preserve the order of appearance - unique_indices = np.sort(unique_indices) + unique_indices = jnp.sort(unique_indices) + # while loop has problems updating JAX arrays, convert them to numpy arrays + A_augmented = np.array(A_augmented) # Extract the unique rows A_augmented = A_augmented[unique_indices] A = A_augmented[:, :-1] @@ -114,9 +116,6 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa indices_row = np.arange(A.shape[0]) indices_idx = np.arange(A.shape[1]) - # while loop has problems updating JAX arrays, convert them to numpy arrays - A = np.array(A) - b = np.array(b) while len(np.where(np.count_nonzero(A, axis=1) == 1)[0]): # fixed just means there is a single element in A, so A_ij*x_j = b_i fixed_rows = np.where(np.count_nonzero(A, axis=1) == 1)[0] From 487c5809de268a9e78e9b86f85cd76e4cac6d32e Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 12:29:45 -0500 Subject: [PATCH 03/18] back to numpy version --- desc/objectives/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 3506830eed..c84a609c31 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -97,13 +97,13 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # which are duplicate rows of A that also have duplicate entries of b, # if the entries of b aren't the same then the constraints are actually # incompatible and so we will leave those to be caught later. - A_augmented = jnp.hstack([A, jnp.reshape(b, (A.shape[0], 1))]) + A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))]) # Find unique rows of A_augmented - unique_rows, unique_indices = jnp.unique(A_augmented, axis=0, return_index=True) + unique_rows, unique_indices = np.unique(A_augmented, axis=0, return_index=True) # Sort the indices to preserve the order of appearance - unique_indices = jnp.sort(unique_indices) + unique_indices = np.sort(unique_indices) # while loop has problems updating JAX arrays, convert them to numpy arrays A_augmented = np.array(A_augmented) From 0bb978ee63b3d64452c28e19f0950899cdfe69b0 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 13:46:01 -0500 Subject: [PATCH 04/18] add benchmark, update unmarked test check --- devtools/check_unmarked_tests.sh | 2 +- tests/benchmarks/benchmark_cpu_small.py | 26 +++++++++++++++++++++++++ tests/benchmarks/benchmark_gpu_small.py | 26 +++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/devtools/check_unmarked_tests.sh b/devtools/check_unmarked_tests.sh index ff51bd1651..e031504e43 100755 --- a/devtools/check_unmarked_tests.sh +++ b/devtools/check_unmarked_tests.sh @@ -10,7 +10,7 @@ start_time=$(date +%s) echo "Files to check: $@" # Collect unmarked tests for the specific file and suppress errors -unmarked=$(pytest "$@" --collect-only -m "not unit and not regression" -q 2> /dev/null | head -n -2) +unmarked=$(pytest "$@" --collect-only -m "not unit and not regression and not benchmark" -q 2> /dev/null | head -n -2) # Count the number of unmarked tests found, ignoring empty lines num_unmarked=$(echo "$unmarked" | sed '/^\s*$/d' | wc -l) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index a9e29be420..2d9b8c1acd 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -424,3 +424,29 @@ def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + + +@pytest.mark.slow +@pytest.mark.benchmark +def test_LinearConstraintProjection_build(benchmark): + """Benchmark LinearConstraintProjection build.""" + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + + def run(obj, con): + lc = LinearConstraintProjection(obj, con) + lc.build() + + benchmark.pedantic( + run, + args=(obj, con), + rounds=10, + iterations=1, + ) diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index 921a4e4451..ef1f69c687 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -424,3 +424,29 @@ def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + + +@pytest.mark.slow +@pytest.mark.benchmark +def test_LinearConstraintProjection_build(benchmark): + """Benchmark LinearConstraintProjection build.""" + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + + def run(obj, con): + lc = LinearConstraintProjection(obj, con) + lc.build() + + benchmark.pedantic( + run, + args=(obj, con), + rounds=10, + iterations=1, + ) From 27260e40a3cafbace95e09eea381fbd54db955d7 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 19:18:06 -0500 Subject: [PATCH 05/18] try qr for null-space --- desc/objectives/utils.py | 7 ++++--- desc/utils.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index c84a609c31..04c1d417f8 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -7,7 +7,7 @@ from desc.backend import cond, jit, jnp, logsumexp, put from desc.io import IOAble -from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif +from desc.utils import Index, errorif, flatten_list, qr_inv_null, unique_list, warnif def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901 @@ -172,11 +172,12 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # null space & particular solution A = A * D[None, unfixed_idx] if A.size: - A_inv, Z = svd_inv_null(A) + x_p, Z = qr_inv_null(A, b) + xp = put(xp, unfixed_idx, x_p) else: A_inv = A.T Z = np.eye(A.shape[1]) - xp = put(xp, unfixed_idx, A_inv @ b) + xp = put(xp, unfixed_idx, A_inv @ b) xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx]) # cast to jnp arrays xp = jnp.asarray(xp) diff --git a/desc/utils.py b/desc/utils.py index 50c1db1b54..02cd2004d3 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -9,7 +9,7 @@ from scipy.special import factorial from termcolor import colored -from desc.backend import flatnonzero, fori_loop, jit, jnp, take +from desc.backend import flatnonzero, fori_loop, jit, jnp, qr, take class Timer: @@ -431,6 +431,40 @@ def svd_inv_null(A): return Ainv, Z +def qr_inv_null(A, b, tol=1e-10): + """Compute pseudo-inverse and null space of a matrix using QR. + + Parameters + ---------- + A : ndarray + Matrix to invert and find null space of. + b : ndarray + Right-hand side of Ax = b. + + Returns + ------- + x_p : ndarray + Particular solution to Ax = b. + Z : ndarray + Null space of A. + + """ + # Linear constraint matrix A is usually wide + # QR decomposition of A^T + Q, R = qr(A.T) + # Determine rank + diag = jnp.abs(jnp.diag(R)) + rank = jnp.sum(diag > tol) + + R1 = R[:rank, :rank] + Q1 = Q[:, :rank] + + # Null space is columns of Q[:, rank:] + Z = Q[:, rank:] + x_p = Q1 @ jnp.linalg.solve(R1.T, b) + return x_p, Z + + def combination_permutation(m, n, equals=True): """Compute all m-tuples of non-negative ints that sum to less than or equal to n. From 6f618b9382c82943060b11d44e664bc711f4a857 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 19:20:17 -0500 Subject: [PATCH 06/18] remove redundant array conversion --- desc/objectives/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 04c1d417f8..c231b8f894 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -105,8 +105,6 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # Sort the indices to preserve the order of appearance unique_indices = np.sort(unique_indices) - # while loop has problems updating JAX arrays, convert them to numpy arrays - A_augmented = np.array(A_augmented) # Extract the unique rows A_augmented = A_augmented[unique_indices] A = A_augmented[:, :-1] From 3102b27145360f43746f0fb4c1020d5b02afbddf Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 19:28:32 -0500 Subject: [PATCH 07/18] use solve_triangular instead --- desc/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desc/utils.py b/desc/utils.py index 02cd2004d3..7085107ca5 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -9,7 +9,7 @@ from scipy.special import factorial from termcolor import colored -from desc.backend import flatnonzero, fori_loop, jit, jnp, qr, take +from desc.backend import flatnonzero, fori_loop, jit, jnp, qr, solve_triangular, take class Timer: @@ -461,7 +461,7 @@ def qr_inv_null(A, b, tol=1e-10): # Null space is columns of Q[:, rank:] Z = Q[:, rank:] - x_p = Q1 @ jnp.linalg.solve(R1.T, b) + x_p = Q1 @ solve_triangular(R1.T, b, lower=True) return x_p, Z From 70bdc81cd98fb79ea4ab504f10991961a29c55b3 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 21:27:48 -0500 Subject: [PATCH 08/18] add check for rank 0 case --- desc/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 7085107ca5..a1a5cd08a2 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -461,7 +461,8 @@ def qr_inv_null(A, b, tol=1e-10): # Null space is columns of Q[:, rank:] Z = Q[:, rank:] - x_p = Q1 @ solve_triangular(R1.T, b, lower=True) + # If rank is 0, then there is no particular solution + x_p = Q1 @ solve_triangular(R1.T, b, lower=True) if rank != 0 else 0 return x_p, Z From 0c0fddb0c93e7c0ecd027f6adf1bf45e93a88f4f Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 22:54:38 -0500 Subject: [PATCH 09/18] refactor solve_fixed_iter test to compiled and first to prevent huge standard deviation --- tests/benchmarks/benchmark_cpu_small.py | 19 ++++++++++++++++++- tests/benchmarks/benchmark_gpu_small.py | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index 2d9b8c1acd..9285a96851 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -411,6 +411,22 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count.""" + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark): eq.change_resolution(6, 6, 6, 12, 12, 12) def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) @pytest.mark.slow diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index ef1f69c687..bc8a6b7af2 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -411,6 +411,22 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count after compilation.""" + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark): eq.change_resolution(6, 6, 6, 12, 12, 12) def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) @pytest.mark.slow From 474c628a29f3935535c6c24a5ab0f2c6b753255a Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 23:31:27 -0500 Subject: [PATCH 10/18] add setup to benchmarks to prevent high standard deviation --- tests/benchmarks/benchmark_cpu_small.py | 53 +++++++++++++++---------- tests/benchmarks/benchmark_gpu_small.py | 51 ++++++++++++++---------- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index 9285a96851..d150d9b499 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -414,48 +414,59 @@ def run(x): @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter_compiled(benchmark): - """Benchmark running eq.solve for fixed iteration count.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) - eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + """Benchmark running eq.solve for fixed iteration count after compilation.""" + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + return eq def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): """Benchmark running eq.solve for fixed iteration count.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + + return eq def run(eq): jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) @pytest.mark.slow @pytest.mark.benchmark def test_LinearConstraintProjection_build(benchmark): """Benchmark LinearConstraintProjection build.""" - jax.clear_caches() - eq = desc.examples.get("W7-X") - obj = ObjectiveFunction(ForceBalance(eq)) - con = get_fixed_boundary_constraints(eq) - con = maybe_add_self_consistency(eq, con) - con = ObjectiveFunction(con) - obj.build() - con.build() + def setup(): + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + return obj, con def run(obj, con): lc = LinearConstraintProjection(obj, con) @@ -463,7 +474,7 @@ def run(obj, con): benchmark.pedantic( run, - args=(obj, con), + setup=setup, rounds=10, iterations=1, ) diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index bc8a6b7af2..3afe4b6c04 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -415,47 +415,58 @@ def run(x): @pytest.mark.benchmark def test_solve_fixed_iter_compiled(benchmark): """Benchmark running eq.solve for fixed iteration count after compilation.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) - eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + return eq def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): """Benchmark running eq.solve for fixed iteration count.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + + return eq def run(eq): jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) @pytest.mark.slow @pytest.mark.benchmark def test_LinearConstraintProjection_build(benchmark): """Benchmark LinearConstraintProjection build.""" - jax.clear_caches() - eq = desc.examples.get("W7-X") - obj = ObjectiveFunction(ForceBalance(eq)) - con = get_fixed_boundary_constraints(eq) - con = maybe_add_self_consistency(eq, con) - con = ObjectiveFunction(con) - obj.build() - con.build() + def setup(): + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + return obj, con def run(obj, con): lc = LinearConstraintProjection(obj, con) @@ -463,7 +474,7 @@ def run(obj, con): benchmark.pedantic( run, - args=(obj, con), + setup=setup, rounds=10, iterations=1, ) From 14f3018d4a9dc4c43dd820539f2069c8693ae290 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 23:52:53 -0500 Subject: [PATCH 11/18] fix setup return type --- tests/benchmarks/benchmark_cpu_small.py | 4 ++-- tests/benchmarks/benchmark_gpu_small.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index d150d9b499..2d55e67cab 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -423,7 +423,7 @@ def setup(): eq.change_resolution(6, 6, 6, 12, 12, 12) eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) - return eq + return (eq,) def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) @@ -442,7 +442,7 @@ def setup(): with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(6, 6, 6, 12, 12, 12) - return eq + return (eq,) def run(eq): jax.clear_caches() diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index 3afe4b6c04..d5b25a495e 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -423,7 +423,7 @@ def setup(): eq.change_resolution(6, 6, 6, 12, 12, 12) eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) - return eq + return (eq,) def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) @@ -442,7 +442,7 @@ def setup(): with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(6, 6, 6, 12, 12, 12) - return eq + return (eq,) def run(eq): jax.clear_caches() From 586de4daa87063cf68a1acff22ce4938ce5333d3 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 19 Nov 2024 00:22:48 -0500 Subject: [PATCH 12/18] try fix --- tests/benchmarks/benchmark_cpu_small.py | 6 +++--- tests/benchmarks/benchmark_gpu_small.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index 2d55e67cab..223c30fe67 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -423,7 +423,7 @@ def setup(): eq.change_resolution(6, 6, 6, 12, 12, 12) eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) - return (eq,) + return (eq,), {} def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) @@ -442,7 +442,7 @@ def setup(): with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(6, 6, 6, 12, 12, 12) - return (eq,) + return (eq,), {} def run(eq): jax.clear_caches() @@ -466,7 +466,7 @@ def setup(): con = ObjectiveFunction(con) obj.build() con.build() - return obj, con + return (obj, con), {} def run(obj, con): lc = LinearConstraintProjection(obj, con) diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index d5b25a495e..623c4da51a 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -423,7 +423,7 @@ def setup(): eq.change_resolution(6, 6, 6, 12, 12, 12) eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) - return (eq,) + return (eq,), {} def run(eq): eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) @@ -442,7 +442,7 @@ def setup(): with pytest.warns(UserWarning, match="Reducing radial"): eq.change_resolution(6, 6, 6, 12, 12, 12) - return (eq,) + return (eq,), {} def run(eq): jax.clear_caches() @@ -466,7 +466,7 @@ def setup(): con = ObjectiveFunction(con) obj.build() con.build() - return obj, con + return (obj, con), {} def run(obj, con): lc = LinearConstraintProjection(obj, con) From c59b58aa146a008f437734bb326bc619fa40ada1 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Tue, 19 Nov 2024 01:56:22 -0500 Subject: [PATCH 13/18] update docs --- desc/objectives/utils.py | 10 +++++----- desc/utils.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index c231b8f894..292d713e80 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -7,11 +7,11 @@ from desc.backend import cond, jit, jnp, logsumexp, put from desc.io import IOAble -from desc.utils import Index, errorif, flatten_list, qr_inv_null, unique_list, warnif +from desc.utils import Index, errorif, flatten_list, qr_xp_null, unique_list, warnif def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901 - """Compute and factorize A to get pseudoinverse and nullspace. + """Compute and factorize A to get particular solution and nullspace. Given constraints of the form Ax=b, factorize A to find a particular solution xp and the null space Z st. Axp=b and AZ=0, so that the full space of solutions to @@ -170,7 +170,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # null space & particular solution A = A * D[None, unfixed_idx] if A.size: - x_p, Z = qr_inv_null(A, b) + x_p, Z = qr_xp_null(A, b) xp = put(xp, unfixed_idx, x_p) else: A_inv = A.T @@ -197,7 +197,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # If the error is very large, likely want to error out as # it probably is due to a real mistake instead of just numerical - # roundoff errors. + # round-off errors. np.testing.assert_allclose( y1, y2, @@ -208,7 +208,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa ) # else check with tighter tols and throw an error, these tolerances - # could be tripped due to just numerical roundoff or poor scaling between + # could be tripped due to just numerical round-off or poor scaling between # constraints, so don't want to error out but we do want to warn the user. atol = 3e-14 rtol = 3e-14 diff --git a/desc/utils.py b/desc/utils.py index a1a5cd08a2..0b5f4bd035 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -431,13 +431,13 @@ def svd_inv_null(A): return Ainv, Z -def qr_inv_null(A, b, tol=1e-10): - """Compute pseudo-inverse and null space of a matrix using QR. +def qr_xp_null(A, b, tol=1e-10): + """Compute null space of a matrix and particular solution Ax=b using QR. Parameters ---------- A : ndarray - Matrix to invert and find null space of. + Matrix to find null space of. b : ndarray Right-hand side of Ax = b. @@ -446,7 +446,7 @@ def qr_inv_null(A, b, tol=1e-10): x_p : ndarray Particular solution to Ax = b. Z : ndarray - Null space of A. + Null space of A such that AZ=0 and Z.T@Z=I. """ # Linear constraint matrix A is usually wide From 3190e41e12218d6b99f8f917f9bbbafc113c38f2 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 20 Nov 2024 23:29:47 -0500 Subject: [PATCH 14/18] update tolerance --- desc/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/desc/utils.py b/desc/utils.py index 0b5f4bd035..ff19bffa54 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -431,7 +431,7 @@ def svd_inv_null(A): return Ainv, Z -def qr_xp_null(A, b, tol=1e-10): +def qr_xp_null(A, b): """Compute null space of a matrix and particular solution Ax=b using QR. Parameters @@ -454,6 +454,7 @@ def qr_xp_null(A, b, tol=1e-10): Q, R = qr(A.T) # Determine rank diag = jnp.abs(jnp.diag(R)) + tol = np.finfo(A.dtype).eps * max(A.shape) * jnp.amax(diag) rank = jnp.sum(diag > tol) R1 = R[:rank, :rank] From 859f0671074d5b738ab6a38273a79c544d7b5b13 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 21 Nov 2024 00:13:28 -0500 Subject: [PATCH 15/18] use jax svd in svd_inv_null --- desc/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/desc/utils.py b/desc/utils.py index ff19bffa54..e7a91b0fdf 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -415,18 +415,19 @@ def svd_inv_null(A): Null space of A. """ - u, s, vh = np.linalg.svd(A, full_matrices=True) + u, s, vh = jnp.linalg.svd(A, full_matrices=True) M, N = u.shape[0], vh.shape[1] K = min(M, N) rcond = np.finfo(A.dtype).eps * max(M, N) - tol = np.amax(s) * rcond + tol = jnp.amax(s) * rcond large = s > tol - num = np.sum(large, dtype=int) + num = jnp.sum(large, dtype=int) uk = u[:, :K] vhk = vh[:K, :] - s = np.divide(1, s, where=large, out=s) - s[(~large,)] = 0 - Ainv = np.matmul(vhk.T, np.multiply(s[..., np.newaxis], uk.T)) + s = jnp.where(large, 1 / s, s) + s.shape + s = s.at[(~large,)].set(0) + Ainv = vhk.T @ jnp.diag(s) @ uk.T Z = vh[num:, :].T.conj() return Ainv, Z From a02addede51cb7151a1c4a26eeca1a3275cfcc85 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 21 Nov 2024 00:30:09 -0500 Subject: [PATCH 16/18] remove qr version, jax svd version is fast enough --- desc/objectives/utils.py | 7 +++---- desc/utils.py | 38 +------------------------------------- 2 files changed, 4 insertions(+), 41 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 292d713e80..49e9770bee 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -7,7 +7,7 @@ from desc.backend import cond, jit, jnp, logsumexp, put from desc.io import IOAble -from desc.utils import Index, errorif, flatten_list, qr_xp_null, unique_list, warnif +from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901 @@ -170,12 +170,11 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # null space & particular solution A = A * D[None, unfixed_idx] if A.size: - x_p, Z = qr_xp_null(A, b) - xp = put(xp, unfixed_idx, x_p) + A_inv, Z = svd_inv_null(A, b) else: A_inv = A.T Z = np.eye(A.shape[1]) - xp = put(xp, unfixed_idx, A_inv @ b) + xp = put(xp, unfixed_idx, A_inv @ b) xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx]) # cast to jnp arrays xp = jnp.asarray(xp) diff --git a/desc/utils.py b/desc/utils.py index e7a91b0fdf..abf5e56665 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -9,7 +9,7 @@ from scipy.special import factorial from termcolor import colored -from desc.backend import flatnonzero, fori_loop, jit, jnp, qr, solve_triangular, take +from desc.backend import flatnonzero, fori_loop, jit, jnp, take class Timer: @@ -432,42 +432,6 @@ def svd_inv_null(A): return Ainv, Z -def qr_xp_null(A, b): - """Compute null space of a matrix and particular solution Ax=b using QR. - - Parameters - ---------- - A : ndarray - Matrix to find null space of. - b : ndarray - Right-hand side of Ax = b. - - Returns - ------- - x_p : ndarray - Particular solution to Ax = b. - Z : ndarray - Null space of A such that AZ=0 and Z.T@Z=I. - - """ - # Linear constraint matrix A is usually wide - # QR decomposition of A^T - Q, R = qr(A.T) - # Determine rank - diag = jnp.abs(jnp.diag(R)) - tol = np.finfo(A.dtype).eps * max(A.shape) * jnp.amax(diag) - rank = jnp.sum(diag > tol) - - R1 = R[:rank, :rank] - Q1 = Q[:, :rank] - - # Null space is columns of Q[:, rank:] - Z = Q[:, rank:] - # If rank is 0, then there is no particular solution - x_p = Q1 @ solve_triangular(R1.T, b, lower=True) if rank != 0 else 0 - return x_p, Z - - def combination_permutation(m, n, equals=True): """Compute all m-tuples of non-negative ints that sum to less than or equal to n. From ca3d967b39fcf39f3310c2beb1f89ed49c84b5c3 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 21 Nov 2024 00:36:00 -0500 Subject: [PATCH 17/18] fix typo --- desc/objectives/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index 49e9770bee..12ea6c0815 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -170,7 +170,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # null space & particular solution A = A * D[None, unfixed_idx] if A.size: - A_inv, Z = svd_inv_null(A, b) + A_inv, Z = svd_inv_null(A) else: A_inv = A.T Z = np.eye(A.shape[1]) From b116a8a07b2844cdf074614e4fba21edf1d9eb29 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 21 Nov 2024 01:16:29 -0500 Subject: [PATCH 18/18] clean up jax svd version --- desc/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/desc/utils.py b/desc/utils.py index abf5e56665..0bc4a0bfb1 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -424,9 +424,7 @@ def svd_inv_null(A): num = jnp.sum(large, dtype=int) uk = u[:, :K] vhk = vh[:K, :] - s = jnp.where(large, 1 / s, s) - s.shape - s = s.at[(~large,)].set(0) + s = jnp.where(large, 1 / s, 0) Ainv = vhk.T @ jnp.diag(s) @ uk.T Z = vh[num:, :].T.conj() return Ainv, Z