Skip to content

Commit

Permalink
Merge branch 'master' into rg/adjoint_ballooning
Browse files Browse the repository at this point in the history
  • Loading branch information
rahulgaur104 authored Aug 24, 2024
2 parents 70f921f + 4281a96 commit 6326a4f
Show file tree
Hide file tree
Showing 33 changed files with 232,455 additions and 2,038,642 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ repos:
- repo: https://github.com/psf/black
rev: 24.3.0
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion desc/objectives/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BootstrapRedlConsistency(_Objective):

_coordinates = "r"
_units = "(T A m^-2)"
_print_value_fmt = "Bootstrap current self-consistency error: {:10.3e} "
_print_value_fmt = "Bootstrap current self-consistency error: "

def __init__(
self,
Expand Down
16 changes: 8 additions & 8 deletions desc/objectives/_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class CoilLength(_CoilObjective):

_scalar = False # Not always a scalar, if a coilset is passed in
_units = "(m)"
_print_value_fmt = "Coil length: {:10.3e} "
_print_value_fmt = "Coil length: "

def __init__(
self,
Expand Down Expand Up @@ -379,7 +379,7 @@ class CoilCurvature(_CoilObjective):

_scalar = False
_units = "(m^-1)"
_print_value_fmt = "Coil curvature: {:10.3e} "
_print_value_fmt = "Coil curvature: "

def __init__(
self,
Expand Down Expand Up @@ -499,7 +499,7 @@ class CoilTorsion(_CoilObjective):

_scalar = False
_units = "(m^-1)"
_print_value_fmt = "Coil torsion: {:10.3e} "
_print_value_fmt = "Coil torsion: "

def __init__(
self,
Expand Down Expand Up @@ -619,7 +619,7 @@ class CoilCurrentLength(CoilLength):

_scalar = False
_units = "(A*m)"
_print_value_fmt = "Coil current length: {:10.3e} "
_print_value_fmt = "Coil current length: "

def __init__(
self,
Expand Down Expand Up @@ -747,7 +747,7 @@ class CoilSetMinDistance(_Objective):

_scalar = False
_units = "(m)"
_print_value_fmt = "Minimum coil-coil distance: {:10.3e} "
_print_value_fmt = "Minimum coil-coil distance: "

def __init__(
self,
Expand Down Expand Up @@ -921,7 +921,7 @@ class PlasmaCoilSetMinDistance(_Objective):

_scalar = False
_units = "(m)"
_print_value_fmt = "Minimum plasma-coil distance: {:10.3e} "
_print_value_fmt = "Minimum plasma-coil distance: "

def __init__(
self,
Expand Down Expand Up @@ -1151,7 +1151,7 @@ class QuadraticFlux(_Objective):

_scalar = False
_linear = False
_print_value_fmt = "Boundary normal field error: {:10.3e} "
_print_value_fmt = "Boundary normal field error: "
_units = "(T m^2)"
_coordinates = "rtz"

Expand Down Expand Up @@ -1353,7 +1353,7 @@ class ToroidalFlux(_Objective):

_coordinates = "rtz"
_units = "(Wb)"
_print_value_fmt = "Toroidal Flux: {:10.3e} "
_print_value_fmt = "Toroidal Flux: "

def __init__(
self,
Expand Down
10 changes: 5 additions & 5 deletions desc/objectives/_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ForceBalance(_Objective):
_equilibrium = True
_coordinates = "rtz"
_units = "(N)"
_print_value_fmt = "Force error: {:10.3e} "
_print_value_fmt = "Force error: "

def __init__(
self,
Expand Down Expand Up @@ -241,7 +241,7 @@ class ForceBalanceAnisotropic(_Objective):
_units = "(N)"
_coordinates = "rtz"
_equilibrium = True
_print_value_fmt = "Anisotropic force error: {:10.3e} "
_print_value_fmt = "Anisotropic force error: "

def __init__(
self,
Expand Down Expand Up @@ -399,7 +399,7 @@ class RadialForceBalance(_Objective):
_equilibrium = True
_coordinates = "rtz"
_units = "(N)"
_print_value_fmt = "Radial force error: {:10.3e} "
_print_value_fmt = "Radial force error: "

def __init__(
self,
Expand Down Expand Up @@ -714,7 +714,7 @@ class Energy(_Objective):
_coordinates = ""
_equilibrium = True
_units = "(J)"
_print_value_fmt = "Total MHD energy: {:10.3e} "
_print_value_fmt = "Total MHD energy: "
_io_attrs_ = _Objective._io_attrs_ + ["gamma"]

def __init__(
Expand Down Expand Up @@ -880,7 +880,7 @@ class CurrentDensity(_Objective):
_equilibrium = True
_coordinates = "rtz"
_units = "(A*m)"
_print_value_fmt = "Current density: {:10.3e} "
_print_value_fmt = "Current density: "

def __init__(
self,
Expand Down
100 changes: 72 additions & 28 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from desc.integrals import DFTInterpolator, FFTInterpolator, virtual_casing_biot_savart
from desc.nestor import Nestor
from desc.objectives.objective_funs import _Objective
from desc.utils import Timer, errorif, warnif
from desc.utils import PRINT_WIDTH, Timer, errorif, warnif

from .normalization import compute_scaling_factors

Expand Down Expand Up @@ -78,7 +78,7 @@ class VacuumBoundaryError(_Objective):

_scalar = False
_linear = False
_print_value_fmt = "Boundary Error: {:10.3e} "
_print_value_fmt = "Boundary Error: "
_units = "(T*m^2, T^2*m^2)"
_coordinates = "rtz"

Expand Down Expand Up @@ -247,11 +247,12 @@ def compute(self, eq_params, field_params=None, constants=None):
Bsq_err = (bsq_in - bsq_out) * g
return jnp.concatenate([Bn_err, Bsq_err])

def print_value(self, *args, **kwargs):
def print_value(self, args, args0=None, **kwargs):
"""Print the value of the objective."""
# this objective is really 2 residuals concatenated so its helpful to print
# them individually
f = self.compute_unscaled(*args, **kwargs)
f0 = self.compute_unscaled(*args0, **kwargs) if args0 is not None else f
# try to do weighted mean if possible
constants = kwargs.get("constants", self.constants)
if constants is None:
Expand All @@ -260,56 +261,77 @@ def print_value(self, *args, **kwargs):
w = constants["quad_weights"]

abserr = jnp.all(self.target == 0)
pre_width = len("Maximum absolute ") if abserr else len("Maximum ")

def _print(fmt, fmax, fmin, fmean, norm, units):
def _print(fmt, fmax, fmin, fmean, f0max, f0min, f0mean, norm, units):

print(
"Maximum " + ("absolute " if abserr else "") + fmt.format(fmax) + units
"Maximum "
+ ("absolute " if abserr else "")
+ fmt.format(f0max, fmax)
+ units
)
print(
"Minimum " + ("absolute " if abserr else "") + fmt.format(fmin) + units
"Minimum "
+ ("absolute " if abserr else "")
+ fmt.format(f0min, fmin)
+ units
)
print(
"Average " + ("absolute " if abserr else "") + fmt.format(fmean) + units
"Average "
+ ("absolute " if abserr else "")
+ fmt.format(f0mean, fmean)
+ units
)

if self._normalize and units != "(dimensionless)":
print(
"Maximum "
+ ("absolute " if abserr else "")
+ fmt.format(fmax / norm)
+ fmt.format(f0max / norm, fmax / norm)
+ "(normalized)"
)
print(
"Minimum "
+ ("absolute " if abserr else "")
+ fmt.format(fmin / norm)
+ fmt.format(f0min / norm, fmin / norm)
+ "(normalized)"
)
print(
"Average "
+ ("absolute " if abserr else "")
+ fmt.format(fmean / norm)
+ fmt.format(f0mean / norm, fmean / norm)
+ "(normalized)"
)

formats = [
"Boundary normal field error: {:10.3e} ",
"Boundary magnetic pressure error: {:10.3e} ",
"Boundary normal field error: ",
"Boundary magnetic pressure error: ",
]
units = ["(T*m^2)", "(T^2*m^2)"]
nn = f.size // 2
norms = [self.normalization[0], self.normalization[nn]]
for i, (fmt, norm, unit) in enumerate(zip(formats, norms, units)):
for i, (fmt, norm, units) in enumerate(zip(formats, norms, units)):
fi = f[i * nn : (i + 1) * nn]
f0i = f0[i * nn : (i + 1) * nn]
# target == 0 probably indicates f is some sort of error metric,
# mean abs makes more sense than mean
fi = jnp.abs(fi) if abserr else fi
f0i = jnp.abs(f0i) if abserr else f0i
wi = w[i * nn : (i + 1) * nn]
fmax = jnp.max(fi)
fmin = jnp.min(fi)
fmean = jnp.mean(fi * wi) / jnp.mean(wi)
_print(fmt, fmax, fmin, fmean, norm, unit)

f0max = jnp.max(f0i)
f0min = jnp.min(f0i)
f0mean = jnp.mean(f0i * wi) / jnp.mean(wi)
fmt = (
f"{fmt:<{PRINT_WIDTH-pre_width}}" + "{:10.3e} --> {:10.3e} "
if args0 is not None
else f"{fmt:<{PRINT_WIDTH-pre_width}}" + "{:10.3e} "
)
_print(fmt, fmax, fmin, fmean, f0max, f0min, f0mean, norm, units)


class BoundaryError(_Objective):
Expand Down Expand Up @@ -409,7 +431,7 @@ class BoundaryError(_Objective):

_scalar = False
_linear = False
_print_value_fmt = "Boundary Error: {:10.3e} "
_print_value_fmt = "Boundary Error: "
_units = "(T*m^2, T^2*m^2, T*m^2)"

_coordinates = "rtz"
Expand Down Expand Up @@ -695,11 +717,12 @@ def compute(self, eq_params, field_params=None, constants=None):
else:
return jnp.concatenate([Bn_err, Bsq_err])

def print_value(self, *args, **kwargs):
def print_value(self, args, args0=None, **kwargs):
"""Print the value of the objective."""
# this objective is really 3 residuals concatenated so its helpful to print
# them individually
f = self.compute_unscaled(*args, **kwargs)
f0 = self.compute_unscaled(*args0, **kwargs) if args0 is not None else f
# try to do weighted mean if possible
constants = kwargs.get("constants", self.constants)
if constants is None:
Expand All @@ -708,43 +731,53 @@ def print_value(self, *args, **kwargs):
w = constants["quad_weights"]

abserr = jnp.all(self.target == 0)
pre_width = len("Maximum absolute ") if abserr else len("Maximum ")

def _print(fmt, fmax, fmin, fmean, norm, units):
def _print(fmt, fmax, fmin, fmean, f0max, f0min, f0mean, norm, unit):

print(
"Maximum " + ("absolute " if abserr else "") + fmt.format(fmax) + units
"Maximum "
+ ("absolute " if abserr else "")
+ fmt.format(f0max, fmax)
+ unit
)
print(
"Minimum " + ("absolute " if abserr else "") + fmt.format(fmin) + units
"Minimum "
+ ("absolute " if abserr else "")
+ fmt.format(f0min, fmin)
+ unit
)
print(
"Average " + ("absolute " if abserr else "") + fmt.format(fmean) + units
"Average "
+ ("absolute " if abserr else "")
+ fmt.format(f0mean, fmean)
+ unit
)

if self._normalize and units != "(dimensionless)":
print(
"Maximum "
+ ("absolute " if abserr else "")
+ fmt.format(fmax / norm)
+ fmt.format(f0max / norm, fmax / norm)
+ "(normalized)"
)
print(
"Minimum "
+ ("absolute " if abserr else "")
+ fmt.format(fmin / norm)
+ fmt.format(f0min / norm, fmin / norm)
+ "(normalized)"
)
print(
"Average "
+ ("absolute " if abserr else "")
+ fmt.format(fmean / norm)
+ fmt.format(f0mean / norm, fmean / norm)
+ "(normalized)"
)

formats = [
"Boundary normal field error: {:10.3e} ",
"Boundary magnetic pressure error: {:10.3e} ",
"Boundary field jump error: {:10.3e} ",
"Boundary normal field error: ",
"Boundary magnetic pressure error: ",
"Boundary field jump error: ",
]
units = ["(T*m^2)", "(T^2*m^2)", "(T*m^2)"]
if self._sheet_current:
Expand All @@ -761,14 +794,25 @@ def _print(fmt, fmax, fmin, fmean, norm, units):
norms = [self.normalization[0], self.normalization[nn]]
for i, (fmt, norm, unit) in enumerate(zip(formats, norms, units)):
fi = f[i * nn : (i + 1) * nn]
f0i = f0[i * nn : (i + 1) * nn]
# target == 0 probably indicates f is some sort of error metric,
# mean abs makes more sense than mean
fi = jnp.abs(fi) if abserr else fi
f0i = jnp.abs(f0i) if abserr else fi
wi = w[i * nn : (i + 1) * nn]
fmax = jnp.max(fi)
fmin = jnp.min(fi)
fmean = jnp.mean(fi * wi) / jnp.mean(wi)
_print(fmt, fmax, fmin, fmean, norm, unit)

f0max = jnp.max(f0i)
f0min = jnp.min(f0i)
f0mean = jnp.mean(f0i * wi) / jnp.mean(wi)
fmt = (
f"{fmt:<{PRINT_WIDTH-pre_width}}" + "{:10.3e} --> {:10.3e} "
if args0 is not None
else f"{fmt:<{PRINT_WIDTH-pre_width}}" + "{:10.3e} "
)
_print(fmt, fmax, fmin, fmean, f0max, f0min, f0mean, norm, unit)


class BoundaryErrorNESTOR(_Objective):
Expand Down Expand Up @@ -828,7 +872,7 @@ class BoundaryErrorNESTOR(_Objective):

_scalar = False
_linear = False
_print_value_fmt = "Boundary magnetic pressure error: {:10.3e} "
_print_value_fmt = "Boundary magnetic pressure error: "
_units = "(T^2*m^2)"
_coordinates = "rtz"

Expand Down
Loading

0 comments on commit 6326a4f

Please sign in to comment.