diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index d02e70bc73..8e0d6c5f34 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -11,7 +11,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901 - """Compute and factorize A to get pseudoinverse and nullspace. + """Compute and factorize A to get particular solution and nullspace. Given constraints of the form Ax=b, factorize A to find a particular solution xp and the null space Z st. Axp=b and AZ=0, so that the full space of solutions to @@ -98,27 +98,15 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # if the entries of b aren't the same then the constraints are actually # incompatible and so we will leave those to be caught later. A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))]) - row_idx_to_delete = np.array([], dtype=int) - for row_idx in range(A_augmented.shape[0]): - # find all rows equal to this row - rows_equal_to_this_row = np.where( - np.all(A_augmented[row_idx, :] == A_augmented, axis=1) - )[0] - # find the rows equal to this row that are not this row - rows_equal_to_this_row_but_not_this_row = rows_equal_to_this_row[ - rows_equal_to_this_row != row_idx - ] - # if there are rows equal to this row that aren't this row, AND this particular - # row has not already been detected as a duplicate of an earlier one and slated - # for deletion, add the duplicate row indices to the array of - # rows to be deleted - if ( - rows_equal_to_this_row_but_not_this_row.size - and row_idx not in row_idx_to_delete - ): - row_idx_to_delete = np.append(row_idx_to_delete, rows_equal_to_this_row[1:]) - # delete the affected rows, and also the corresponding rows of b - A_augmented = np.delete(A_augmented, row_idx_to_delete, axis=0) + + # Find unique rows of A_augmented + unique_rows, unique_indices = np.unique(A_augmented, axis=0, return_index=True) + + # Sort the indices to preserve the order of appearance + unique_indices = np.sort(unique_indices) + + # Extract the unique rows + A_augmented = A_augmented[unique_indices] A = A_augmented[:, :-1] b = np.atleast_1d(A_augmented[:, -1].squeeze()) @@ -126,9 +114,6 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa indices_row = np.arange(A.shape[0]) indices_idx = np.arange(A.shape[1]) - # while loop has problems updating JAX arrays, convert them to numpy arrays - A = np.array(A) - b = np.array(b) while len(np.where(np.count_nonzero(A, axis=1) == 1)[0]): # fixed just means there is a single element in A, so A_ij*x_j = b_i fixed_rows = np.where(np.count_nonzero(A, axis=1) == 1)[0] @@ -211,7 +196,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa # If the error is very large, likely want to error out as # it probably is due to a real mistake instead of just numerical - # roundoff errors. + # round-off errors. np.testing.assert_allclose( y1, y2, @@ -222,7 +207,7 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa ) # else check with tighter tols and throw an error, these tolerances - # could be tripped due to just numerical roundoff or poor scaling between + # could be tripped due to just numerical round-off or poor scaling between # constraints, so don't want to error out but we do want to warn the user. atol = 3e-14 rtol = 3e-14 diff --git a/desc/utils.py b/desc/utils.py index 50c1db1b54..0bc4a0bfb1 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -415,18 +415,17 @@ def svd_inv_null(A): Null space of A. """ - u, s, vh = np.linalg.svd(A, full_matrices=True) + u, s, vh = jnp.linalg.svd(A, full_matrices=True) M, N = u.shape[0], vh.shape[1] K = min(M, N) rcond = np.finfo(A.dtype).eps * max(M, N) - tol = np.amax(s) * rcond + tol = jnp.amax(s) * rcond large = s > tol - num = np.sum(large, dtype=int) + num = jnp.sum(large, dtype=int) uk = u[:, :K] vhk = vh[:K, :] - s = np.divide(1, s, where=large, out=s) - s[(~large,)] = 0 - Ainv = np.matmul(vhk.T, np.multiply(s[..., np.newaxis], uk.T)) + s = jnp.where(large, 1 / s, 0) + Ainv = vhk.T @ jnp.diag(s) @ uk.T Z = vh[num:, :].T.conj() return Ainv, Z diff --git a/devtools/check_unmarked_tests.sh b/devtools/check_unmarked_tests.sh index ff51bd1651..e031504e43 100755 --- a/devtools/check_unmarked_tests.sh +++ b/devtools/check_unmarked_tests.sh @@ -10,7 +10,7 @@ start_time=$(date +%s) echo "Files to check: $@" # Collect unmarked tests for the specific file and suppress errors -unmarked=$(pytest "$@" --collect-only -m "not unit and not regression" -q 2> /dev/null | head -n -2) +unmarked=$(pytest "$@" --collect-only -m "not unit and not regression and not benchmark" -q 2> /dev/null | head -n -2) # Count the number of unmarked tests found, ignoring empty lines num_unmarked=$(echo "$unmarked" | sed '/^\s*$/d' | wc -l) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index a9e29be420..223c30fe67 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -411,16 +411,70 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count after compilation.""" + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + return (eq,), {} + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): """Benchmark running eq.solve for fixed iteration count.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + + return (eq,), {} def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) + + +@pytest.mark.slow +@pytest.mark.benchmark +def test_LinearConstraintProjection_build(benchmark): + """Benchmark LinearConstraintProjection build.""" + + def setup(): + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + return (obj, con), {} + + def run(obj, con): + lc = LinearConstraintProjection(obj, con) + lc.build() + + benchmark.pedantic( + run, + setup=setup, + rounds=10, + iterations=1, + ) diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index 921a4e4451..623c4da51a 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -411,16 +411,70 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count after compilation.""" + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + return (eq,), {} + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): """Benchmark running eq.solve for fixed iteration count.""" - jax.clear_caches() - eq = desc.examples.get("ESTELL") - with pytest.warns(UserWarning, match="Reducing radial"): - eq.change_resolution(6, 6, 6, 12, 12, 12) + + def setup(): + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + + return (eq,), {} def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, setup=setup, rounds=5, iterations=1) + + +@pytest.mark.slow +@pytest.mark.benchmark +def test_LinearConstraintProjection_build(benchmark): + """Benchmark LinearConstraintProjection build.""" + + def setup(): + jax.clear_caches() + eq = desc.examples.get("W7-X") + + obj = ObjectiveFunction(ForceBalance(eq)) + con = get_fixed_boundary_constraints(eq) + con = maybe_add_self_consistency(eq, con) + con = ObjectiveFunction(con) + obj.build() + con.build() + return (obj, con), {} + + def run(obj, con): + lc = LinearConstraintProjection(obj, con) + lc.build() + + benchmark.pedantic( + run, + setup=setup, + rounds=10, + iterations=1, + )