diff --git a/CHANGELOG.md b/CHANGELOG.md index 88bcdc4cf..19d182520 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ New Features - Adds an option ``scaled_termination`` (defaults to True) to all the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``. - ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface. - Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``. +- Adds the output quantities ``wb`` and ``wp`` to ``VMECIO.save``. Bug Fixes diff --git a/desc/examples/__init__.py b/desc/examples/__init__.py index e8eeb5966..8c0f253ad 100644 --- a/desc/examples/__init__.py +++ b/desc/examples/__init__.py @@ -3,6 +3,7 @@ import os import desc.io +from desc.backend import execute_on_cpu def listall(): @@ -13,6 +14,7 @@ def listall(): return names_stripped +@execute_on_cpu def get(name, data=None): """Get example equilibria and data. diff --git a/desc/optimize/tr_subproblems.py b/desc/optimize/tr_subproblems.py index 8c39e8229..4cc8e58fb 100644 --- a/desc/optimize/tr_subproblems.py +++ b/desc/optimize/tr_subproblems.py @@ -205,10 +205,11 @@ def phi_and_derivative(alpha, suf, s, trust_radius): """ denom = s**2 + alpha denom = jnp.where(denom == 0, 1, denom) - p_norm = jnp.linalg.norm(suf / denom) + p = -v.dot(suf / denom) + p_norm = jnp.linalg.norm(p) phi = p_norm - trust_radius phi_prime = -jnp.sum(suf**2 / denom**3) / p_norm - return phi, phi_prime + return p, phi, phi_prime # Check if J has full rank and try Gauss-Newton step. threshold = setdefault(threshold, jnp.finfo(s.dtype).eps * f.size) @@ -222,43 +223,34 @@ def truefun(*_): return p_newton, False, 0.0 def falsefun(*_): - alpha_upper = jnp.linalg.norm(suf) / trust_radius alpha_lower = 0.0 - alpha = setdefault( - initial_alpha, - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - ) + alpha = setdefault(initial_alpha, 0.0) + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) - phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius) + _, phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius) k = 0 def loop_cond(state): - alpha, alpha_lower, alpha_upper, phi, k = state + p, alpha, alpha_lower, alpha_upper, phi, k = state return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter) def loop_body(state): - alpha, alpha_lower, alpha_upper, phi, k = state - alpha = jnp.where( - (alpha < alpha_lower) | (alpha > alpha_upper), - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - alpha, - ) - - phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius) + p, alpha, alpha_lower, alpha_upper, phi, k = state + + p, phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius) alpha_upper = jnp.where(phi < 0, alpha, alpha_upper) ratio = phi / phi_prime alpha_lower = jnp.maximum(alpha_lower, alpha - ratio) alpha -= (phi + trust_radius) * ratio / trust_radius + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) k += 1 - return alpha, alpha_lower, alpha_upper, phi, k + return p, alpha, alpha_lower, alpha_upper, phi, k - alpha, *_ = while_loop( - loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, phi, k) + p, alpha, *_ = while_loop( + loop_cond, loop_body, (p_newton, alpha, alpha_lower, alpha_upper, phi, k) ) - p = -v.dot(suf / (s**2 + alpha)) - # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region # (which can cause problems later). @@ -318,25 +310,19 @@ def truefun(*_): def falsefun(*_): alpha_upper = jnp.linalg.norm(g) / trust_radius alpha_lower = 0.0 - alpha = setdefault( - initial_alpha, - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - ) + # the final alpha value is very small. So, starting from 0 + # is faster for root finding + alpha = setdefault(initial_alpha, alpha_lower) + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) k = 0 # algorithm 4.3 from Nocedal & Wright def loop_cond(state): - alpha, alpha_lower, alpha_upper, phi, k = state + p, alpha, alpha_lower, alpha_upper, phi, k = state return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter) def loop_body(state): - alpha, alpha_lower, alpha_upper, phi, k = state - - alpha = jnp.where( - (alpha < alpha_lower) | (alpha > alpha_upper), - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - alpha, - ) + p, alpha, alpha_lower, alpha_upper, phi, k = state Bi = B + alpha * jnp.eye(B.shape[0]) R = chol(Bi) @@ -350,21 +336,16 @@ def loop_body(state): q_norm = jnp.linalg.norm(q) alpha += (p_norm / q_norm) ** 2 * phi / trust_radius - alpha = jnp.where( - (alpha < alpha_lower) | (alpha > alpha_upper), - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - alpha, - ) + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) k += 1 - return alpha, alpha_lower, alpha_upper, phi, k + return p, alpha, alpha_lower, alpha_upper, phi, k - alpha, *_ = while_loop( - loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) + p, alpha, *_ = while_loop( + loop_cond, + loop_body, + (p_newton, alpha, alpha_lower, alpha_upper, jnp.inf, k), ) - Bi = B + alpha * jnp.eye(B.shape[0]) - R = chol(Bi) - p = cho_solve((R, True), -g) # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region @@ -385,6 +366,15 @@ def trust_region_step_exact_qr( Solves problems of the form min_p ||J*p + f||^2, ||p|| < trust_radius + Introduces a Levenberg-Marquardt parameter alpha to make the problem + well-conditioned. + min_p ||J*p + f||^2 + alpha*||p||^2, ||p|| < trust_radius + + The objective function can be written as + p.T(J.T@J + alpha*I)p + 2f.TJp + f.Tf + which is equivalent to + || [J; sqrt(alpha)*I].Tp - [f; 0].T ||^2 + Parameters ---------- f : ndarray @@ -421,26 +411,20 @@ def truefun(*_): def falsefun(*_): alpha_upper = jnp.linalg.norm(J.T @ f) / trust_radius alpha_lower = 0.0 - alpha = setdefault( - initial_alpha, - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - ) + # the final alpha value is very small. So, starting from 0 + # is faster for root finding + alpha = setdefault(initial_alpha, alpha_lower) + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) k = 0 - # algorithm 4.3 from Nocedal & Wright + fp = jnp.pad(f, (0, J.shape[1])) def loop_cond(state): - alpha, alpha_lower, alpha_upper, phi, k = state + p, alpha, alpha_lower, alpha_upper, phi, k = state return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter) def loop_body(state): - alpha, alpha_lower, alpha_upper, phi, k = state - - alpha = jnp.where( - (alpha < alpha_lower) | (alpha > alpha_upper), - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - alpha, - ) + p, alpha, alpha_lower, alpha_upper, phi, k = state Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) # Ji is always tall since its padded by alpha*I @@ -456,22 +440,16 @@ def loop_body(state): q_norm = jnp.linalg.norm(q) alpha += (p_norm / q_norm) ** 2 * phi / trust_radius - alpha = jnp.where( - (alpha < alpha_lower) | (alpha > alpha_upper), - jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5), - alpha, - ) + alpha = jnp.clip(alpha, alpha_lower, alpha_upper) k += 1 - return alpha, alpha_lower, alpha_upper, phi, k + return p, alpha, alpha_lower, alpha_upper, phi, k - alpha, *_ = while_loop( - loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k) + p, alpha, *_ = while_loop( + loop_cond, + loop_body, + (p_newton, alpha, alpha_lower, alpha_upper, jnp.inf, k), ) - Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])]) - Q, R = qr(Ji, mode="economic") - p = solve_triangular(R, -Q.T @ fp) - # Make the norm of p equal to trust_radius; p is changed only slightly. # This is done to prevent p from lying outside the trust region # (which can cause problems later). diff --git a/desc/vmec.py b/desc/vmec.py index 14c4cafff..c6743c8d2 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -289,7 +289,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa: grid_full = LinearGrid(M=M_nyq, N=N_nyq, NFP=NFP, rho=r_full) data_quad = eq.compute( - ["R0/a", "V", "<|B|>_rms", "_vol", "_vol", "_vol"] + [ + "R0/a", + "V", + "W_B", + "W_p", + "<|B|>_rms", + "_vol", + "_vol", + "_vol", + ] ) data_axis = eq.compute(["G", "p", "R", "<|B|^2>", "<|B|>"], grid=grid_axis) data_lcfs = eq.compute(["G", "I", "R", "Z"], grid=grid_lcfs) @@ -502,6 +511,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa: betator.units = "None" betator[:] = data_quad["_vol"] + wb = file.createVariable("wb", np.float64) + wb.long_name = "plasma magnetic energy * mu_0/(4*pi^2)" + wb.units = "T^2*m^3" + wb[:] = data_quad["W_B"] * mu_0 / (4 * np.pi**2) + + wp = file.createVariable("wp", np.float64) + wp.long_name = "plasma thermodynamic energy * mu_0/(4*pi^2)" + wp.units = "T^2*m^3" + wp[:] = np.abs(data_quad["W_p"]) * mu_0 / (4 * np.pi**2) + # scalars computed at the magnetic axis rbtor0 = file.createVariable("rbtor0", np.float64) @@ -1338,16 +1357,8 @@ def fullfit(x): specw = file.createVariable("specw", np.float64, ("radius",)) specw[:] = np.zeros((file.dimensions["radius"].size,)) - # this is not the same as DESC's "W_B" - wb = file.createVariable("wb", np.float64) - wb[:] = 0.0 - wdot = file.createVariable("wdot", np.float64, ("time",)) wdot[:] = np.zeros((file.dimensions["time"].size,)) - - # this is not the same as DESC's "W_p" - wp = file.createVariable("wp", np.float64) - wp[:] = 0.0 """ file.close() diff --git a/docs/installation.rst b/docs/installation.rst index 9df67563b..403b1c658 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -96,17 +96,18 @@ specific JAX GPU installation instructions, as that is the main installation dif **Note that DESC does not always test on or guarantee support of the latest version of JAX (which does not have a stable 1.0 release yet), and thus older versions of GPU-accelerated versions of JAX may need to be installed, which may in turn require lower versions of JaxLib, as well as CUDA and CuDNN.** + Perlmutter (NERSC) ++++++++++++++++++++++++++++++ -These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on June 18, 2024. +These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on December 17, 2024. Set up the correct cuda environment for jax installation .. code-block:: sh - module load cudatoolkit/12.2 + module load cudatoolkit/12.4 module load cudnn/8.9.3_cuda12 - module load python + module load python/3.11 Check that you have loaded these modules @@ -118,21 +119,9 @@ Create a conda environment for DESC (`following these instructions = 1.7.0, < 2.0.0 - -to - -.. code-block:: sh - - scipy >= 1.7.0, <= 1.11.3 + pip install --upgrade "jax[cuda12]" Clone and install DESC diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 637d247a8..82e9130ee 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1416,7 +1416,8 @@ def test_optimize_coil_currents(DummyCoilSet): verbose=2, copy=True, ) - # check that optimized coil currents changed by more than 15% from initial values + # check that on average optimized coil currents changed by more than + # 15% from initial values np.testing.assert_array_less( np.asarray(coils.current).mean() * 0.15, np.abs(np.asarray(coils_opt.current) - np.asarray(coils.current)).mean(), diff --git a/tests/test_vmec.py b/tests/test_vmec.py index 91b34667a..2edca5c28 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -496,6 +496,10 @@ def test_vmec_save_1(VMEC_save): np.testing.assert_allclose( vmec.variables["betator"][:], desc.variables["betator"][:], rtol=1e-5 ) + np.testing.assert_allclose(vmec.variables["wb"][:], desc.variables["wb"][:]) + np.testing.assert_allclose( + vmec.variables["wp"][:], desc.variables["wp"][:], rtol=1e-6 + ) np.testing.assert_allclose( vmec.variables["ctor"][:], desc.variables["ctor"][:], rtol=1e-5 )