Skip to content

Commit

Permalink
Merge branch 'master' into dd/external
Browse files Browse the repository at this point in the history
  • Loading branch information
ddudt authored Dec 18, 2024
2 parents 24dd2f3 + 93d2564 commit d3aa2dd
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ New Feature
- Adds an option ``scaled_termination`` (defaults to True) to all of the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.
- ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface.
- Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``.
- Adds the output quantities ``wb`` and ``wp`` to ``VMECIO.save``.
- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences.

Bug Fixes
Expand Down
29 changes: 20 additions & 9 deletions desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa:
grid_full = LinearGrid(M=M_nyq, N=N_nyq, NFP=NFP, rho=r_full)

data_quad = eq.compute(
["R0/a", "V", "<|B|>_rms", "<beta>_vol", "<beta_pol>_vol", "<beta_tor>_vol"]
[
"R0/a",
"V",
"W_B",
"W_p",
"<|B|>_rms",
"<beta>_vol",
"<beta_pol>_vol",
"<beta_tor>_vol",
]
)
data_axis = eq.compute(["G", "p", "R", "<|B|^2>", "<|B|>"], grid=grid_axis)
data_lcfs = eq.compute(["G", "I", "R", "Z"], grid=grid_lcfs)
Expand Down Expand Up @@ -502,6 +511,16 @@ def save(cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None): # noqa:
betator.units = "None"
betator[:] = data_quad["<beta_tor>_vol"]

wb = file.createVariable("wb", np.float64)
wb.long_name = "plasma magnetic energy * mu_0/(4*pi^2)"
wb.units = "T^2*m^3"
wb[:] = data_quad["W_B"] * mu_0 / (4 * np.pi**2)

wp = file.createVariable("wp", np.float64)
wp.long_name = "plasma thermodynamic energy * mu_0/(4*pi^2)"
wp.units = "T^2*m^3"
wp[:] = np.abs(data_quad["W_p"]) * mu_0 / (4 * np.pi**2)

# scalars computed at the magnetic axis

rbtor0 = file.createVariable("rbtor0", np.float64)
Expand Down Expand Up @@ -1338,16 +1357,8 @@ def fullfit(x):
specw = file.createVariable("specw", np.float64, ("radius",))
specw[:] = np.zeros((file.dimensions["radius"].size,))
# this is not the same as DESC's "W_B"
wb = file.createVariable("wb", np.float64)
wb[:] = 0.0
wdot = file.createVariable("wdot", np.float64, ("time",))
wdot[:] = np.zeros((file.dimensions["time"].size,))
# this is not the same as DESC's "W_p"
wp = file.createVariable("wp", np.float64)
wp[:] = 0.0
"""

file.close()
Expand Down
23 changes: 6 additions & 17 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,18 @@ specific JAX GPU installation instructions, as that is the main installation dif

**Note that DESC does not always test on or guarantee support of the latest version of JAX (which does not have a stable 1.0 release yet), and thus older versions of GPU-accelerated versions of JAX may need to be installed, which may in turn require lower versions of JaxLib, as well as CUDA and CuDNN.**


Perlmutter (NERSC)
++++++++++++++++++++++++++++++
These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on June 18, 2024.
These instructions were tested and confirmed to work on the Perlmutter supercomputer at NERSC on December 17, 2024.

Set up the correct cuda environment for jax installation

.. code-block:: sh
module load cudatoolkit/12.2
module load cudatoolkit/12.4
module load cudnn/8.9.3_cuda12
module load python
module load python/3.11
Check that you have loaded these modules

Expand All @@ -118,21 +119,9 @@ Create a conda environment for DESC (`following these instructions <https://docs

.. code-block:: sh
conda create -n desc-env python=3.9
conda create -n desc-env python=3.11
conda activate desc-env
pip install --no-cache-dir "jax==0.4.23" "jaxlib[cuda12_cudnn89]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For Perlmutter installation, please change the scipy version from

.. code-block:: sh
scipy >= 1.7.0, < 2.0.0
to

.. code-block:: sh
scipy >= 1.7.0, <= 1.11.3
pip install --upgrade "jax[cuda12]"
Clone and install DESC

Expand Down
4 changes: 4 additions & 0 deletions tests/test_vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ def test_vmec_save_1(VMEC_save):
np.testing.assert_allclose(
vmec.variables["betator"][:], desc.variables["betator"][:], rtol=1e-5
)
np.testing.assert_allclose(vmec.variables["wb"][:], desc.variables["wb"][:])
np.testing.assert_allclose(
vmec.variables["wp"][:], desc.variables["wp"][:], rtol=1e-6
)
np.testing.assert_allclose(
vmec.variables["ctor"][:], desc.variables["ctor"][:], rtol=1e-5
)
Expand Down

0 comments on commit d3aa2dd

Please sign in to comment.