Skip to content

Commit

Permalink
Merge branch 'master' into ku/fourier_bounce_neo
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Dec 19, 2024
2 parents e575712 + 34dbac0 commit 432933b
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 97 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ New Features
- Adds an option ``scaled_termination`` (defaults to True) to all the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.
- ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface.
- Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``.
- Adds the output quantities ``wb`` and ``wp`` to ``VMECIO.save``.

Bug Fixes

Expand Down
2 changes: 2 additions & 0 deletions desc/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import desc.io
from desc.backend import execute_on_cpu


def listall():
Expand All @@ -13,6 +14,7 @@ def listall():
return names_stripped


@execute_on_cpu
def get(name, data=None):
"""Get example equilibria and data.
Expand Down
118 changes: 48 additions & 70 deletions desc/optimize/tr_subproblems.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,11 @@ def phi_and_derivative(alpha, suf, s, trust_radius):
"""
denom = s**2 + alpha
denom = jnp.where(denom == 0, 1, denom)
p_norm = jnp.linalg.norm(suf / denom)
p = -v.dot(suf / denom)
p_norm = jnp.linalg.norm(p)
phi = p_norm - trust_radius
phi_prime = -jnp.sum(suf**2 / denom**3) / p_norm
return phi, phi_prime
return p, phi, phi_prime

# Check if J has full rank and try Gauss-Newton step.
threshold = setdefault(threshold, jnp.finfo(s.dtype).eps * f.size)
Expand All @@ -222,43 +223,34 @@ def truefun(*_):
return p_newton, False, 0.0

def falsefun(*_):

alpha_upper = jnp.linalg.norm(suf) / trust_radius
alpha_lower = 0.0
alpha = setdefault(
initial_alpha,
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
)
alpha = setdefault(initial_alpha, 0.0)
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)

phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius)
_, phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius)
k = 0

def loop_cond(state):
alpha, alpha_lower, alpha_upper, phi, k = state
p, alpha, alpha_lower, alpha_upper, phi, k = state
return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter)

def loop_body(state):
alpha, alpha_lower, alpha_upper, phi, k = state
alpha = jnp.where(
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)

phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius)
p, alpha, alpha_lower, alpha_upper, phi, k = state

p, phi, phi_prime = phi_and_derivative(alpha, suf, s, trust_radius)
alpha_upper = jnp.where(phi < 0, alpha, alpha_upper)
ratio = phi / phi_prime
alpha_lower = jnp.maximum(alpha_lower, alpha - ratio)
alpha -= (phi + trust_radius) * ratio / trust_radius
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)
k += 1
return alpha, alpha_lower, alpha_upper, phi, k
return p, alpha, alpha_lower, alpha_upper, phi, k

alpha, *_ = while_loop(
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, phi, k)
p, alpha, *_ = while_loop(
loop_cond, loop_body, (p_newton, alpha, alpha_lower, alpha_upper, phi, k)
)

p = -v.dot(suf / (s**2 + alpha))

# Make the norm of p equal to trust_radius; p is changed only slightly.
# This is done to prevent p from lying outside the trust region
# (which can cause problems later).
Expand Down Expand Up @@ -318,25 +310,19 @@ def truefun(*_):
def falsefun(*_):
alpha_upper = jnp.linalg.norm(g) / trust_radius
alpha_lower = 0.0
alpha = setdefault(
initial_alpha,
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
)
# The final alpha value is typically very small, so starting from 0
# makes the root finding converge faster.
alpha = setdefault(initial_alpha, alpha_lower)
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)
k = 0
# algorithm 4.3 from Nocedal & Wright

def loop_cond(state):
alpha, alpha_lower, alpha_upper, phi, k = state
p, alpha, alpha_lower, alpha_upper, phi, k = state
return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter)

def loop_body(state):
alpha, alpha_lower, alpha_upper, phi, k = state

alpha = jnp.where(
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)
p, alpha, alpha_lower, alpha_upper, phi, k = state

Bi = B + alpha * jnp.eye(B.shape[0])
R = chol(Bi)
Expand All @@ -350,21 +336,16 @@ def loop_body(state):
q_norm = jnp.linalg.norm(q)

alpha += (p_norm / q_norm) ** 2 * phi / trust_radius
alpha = jnp.where(
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)

k += 1
return alpha, alpha_lower, alpha_upper, phi, k
return p, alpha, alpha_lower, alpha_upper, phi, k

alpha, *_ = while_loop(
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
p, alpha, *_ = while_loop(
loop_cond,
loop_body,
(p_newton, alpha, alpha_lower, alpha_upper, jnp.inf, k),
)
Bi = B + alpha * jnp.eye(B.shape[0])
R = chol(Bi)
p = cho_solve((R, True), -g)

# Make the norm of p equal to trust_radius; p is changed only slightly.
# This is done to prevent p from lying outside the trust region
Expand All @@ -385,6 +366,15 @@ def trust_region_step_exact_qr(
Solves problems of the form
min_p ||J*p + f||^2, ||p|| < trust_radius
Introduces a Levenberg-Marquardt parameter alpha to make the problem
well-conditioned.
min_p ||J*p + f||^2 + alpha*||p||^2, ||p|| < trust_radius
The objective function can be written as
p.T@(J.T@J + alpha*I)@p + 2*f.T@J@p + f.T@f
which is equivalent to
|| [J; sqrt(alpha)*I]@p + [f; 0] ||^2
Parameters
----------
f : ndarray
Expand Down Expand Up @@ -421,26 +411,20 @@ def truefun(*_):
def falsefun(*_):
alpha_upper = jnp.linalg.norm(J.T @ f) / trust_radius
alpha_lower = 0.0
alpha = setdefault(
initial_alpha,
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
)
# The final alpha value is typically very small, so starting from 0
# makes the root finding converge faster.
alpha = setdefault(initial_alpha, alpha_lower)
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)
k = 0
# algorithm 4.3 from Nocedal & Wright

fp = jnp.pad(f, (0, J.shape[1]))

def loop_cond(state):
alpha, alpha_lower, alpha_upper, phi, k = state
p, alpha, alpha_lower, alpha_upper, phi, k = state
return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter)

def loop_body(state):
alpha, alpha_lower, alpha_upper, phi, k = state

alpha = jnp.where(
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)
p, alpha, alpha_lower, alpha_upper, phi, k = state

Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
# Ji is always tall since it's padded by sqrt(alpha)*I
Expand All @@ -456,22 +440,16 @@ def loop_body(state):
q_norm = jnp.linalg.norm(q)

alpha += (p_norm / q_norm) ** 2 * phi / trust_radius
alpha = jnp.where(
(alpha < alpha_lower) | (alpha > alpha_upper),
jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
alpha,
)
alpha = jnp.clip(alpha, alpha_lower, alpha_upper)
k += 1
return alpha, alpha_lower, alpha_upper, phi, k
return p, alpha, alpha_lower, alpha_upper, phi, k

alpha, *_ = while_loop(
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
p, alpha, *_ = while_loop(
loop_cond,
loop_body,
(p_newton, alpha, alpha_lower, alpha_upper, jnp.inf, k),
)

Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
Q, R = qr(Ji, mode="economic")
p = solve_triangular(R, -Q.T @ fp)

# Make the norm of p equal to trust_radius; p is changed only slightly.
# This is done to prevent p from lying outside the trust region
# (which can cause problems later).
Expand Down
29 changes: 20 additions & 9 deletions desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa:
grid_full = LinearGrid(M=M_nyq, N=N_nyq, NFP=NFP, rho=r_full)

data_quad = eq.compute(
["R0/a", "V", "<|B|>_rms", "<beta>_vol", "<beta_pol>_vol", "<beta_tor>_vol"]
[
"R0/a",
"V",
"W_B",
"W_p",
"<|B|>_rms",
"<beta>_vol",
"<beta_pol>_vol",
"<beta_tor>_vol",
]
)
data_axis = eq.compute(["G", "p", "R", "<|B|^2>", "<|B|>"], grid=grid_axis)
data_lcfs = eq.compute(["G", "I", "R", "Z"], grid=grid_lcfs)
Expand Down Expand Up @@ -502,6 +511,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa:
betator.units = "None"
betator[:] = data_quad["<beta_tor>_vol"]

wb = file.createVariable("wb", np.float64)
wb.long_name = "plasma magnetic energy * mu_0/(4*pi^2)"
wb.units = "T^2*m^3"
wb[:] = data_quad["W_B"] * mu_0 / (4 * np.pi**2)

wp = file.createVariable("wp", np.float64)
wp.long_name = "plasma thermodynamic energy * mu_0/(4*pi^2)"
wp.units = "T^2*m^3"
wp[:] = np.abs(data_quad["W_p"]) * mu_0 / (4 * np.pi**2)

# scalars computed at the magnetic axis

rbtor0 = file.createVariable("rbtor0", np.float64)
Expand Down Expand Up @@ -1338,16 +1357,8 @@ def fullfit(x):
specw = file.createVariable("specw", np.float64, ("radius",))
specw[:] = np.zeros((file.dimensions["radius"].size,))
# this is not the same as DESC's "W_B"
wb = file.createVariable("wb", np.float64)
wb[:] = 0.0
wdot = file.createVariable("wdot", np.float64, ("time",))
wdot[:] = np.zeros((file.dimensions["time"].size,))
# this is not the same as DESC's "W_p"
wp = file.createVariable("wp", np.float64)
wp[:] = 0.0
"""

file.close()
Expand Down
23 changes: 6 additions & 17 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ specific JAX GPU installation instructions, as that is the main installation dif

**Note that DESC does not always test on or guarantee support of the latest version of JAX (which does not have a stable 1.0 release yet), and thus older GPU-accelerated versions of JAX may need to be installed, which may in turn require older versions of JaxLib, as well as CUDA and CuDNN.**


Perlmutter (NERSC)
++++++++++++++++++++++++++++++
These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on June 18, 2024.
These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on December 17, 2024.

Set up the correct CUDA environment for JAX installation

.. code-block:: sh
module load cudatoolkit/12.2
module load cudatoolkit/12.4
module load cudnn/8.9.3_cuda12
module load python
module load python/3.11
Check that you have loaded these modules

Expand All @@ -118,21 +119,9 @@ Create a conda environment for DESC (`following these instructions <https://docs

.. code-block:: sh
conda create -n desc-env python=3.9
conda create -n desc-env python=3.11
conda activate desc-env
pip install --no-cache-dir "jax==0.4.23" "jaxlib[cuda12_cudnn89]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For Perlmutter installation, please change the scipy version from

.. code-block:: sh
scipy >= 1.7.0, < 2.0.0
to

.. code-block:: sh
scipy >= 1.7.0, <= 1.11.3
pip install --upgrade "jax[cuda12]"
Clone and install DESC

Expand Down
3 changes: 2 additions & 1 deletion tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,8 @@ def test_optimize_coil_currents(DummyCoilSet):
verbose=2,
copy=True,
)
# check that optimized coil currents changed by more than 15% from initial values
# check that on average optimized coil currents changed by more than
# 15% from initial values
np.testing.assert_array_less(
np.asarray(coils.current).mean() * 0.15,
np.abs(np.asarray(coils_opt.current) - np.asarray(coils.current)).mean(),
Expand Down
4 changes: 4 additions & 0 deletions tests/test_vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def test_vmec_save_1(VMEC_save):
np.testing.assert_allclose(
vmec.variables["betator"][:], desc.variables["betator"][:], rtol=1e-5
)
np.testing.assert_allclose(vmec.variables["wb"][:], desc.variables["wb"][:])
np.testing.assert_allclose(
vmec.variables["wp"][:], desc.variables["wp"][:], rtol=1e-6
)
np.testing.assert_allclose(
vmec.variables["ctor"][:], desc.variables["ctor"][:], rtol=1e-5
)
Expand Down

0 comments on commit 432933b

Please sign in to comment.