Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scaled Linear Constraints #1158

Merged
merged 36 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
c09443e
coil current scale attribute
daniel-dudt Jul 31, 2024
e0d8b9b
fixing init methods
daniel-dudt Jul 31, 2024
c674795
very simple test
daniel-dudt Jul 31, 2024
73c3007
revert all commits LOL
daniel-dudt Aug 9, 2024
6a76eb0
add scaling matrix to factorize_linear_constraints
daniel-dudt Aug 9, 2024
82ac97f
return diag scaling matrix D
daniel-dudt Aug 9, 2024
a778b61
Merge branch 'master' into dd/current_scale
ddudt Aug 9, 2024
964c72c
improved scaling
daniel-dudt Aug 9, 2024
2a61202
change D to rescale x_reduced instead of x_full
daniel-dudt Aug 12, 2024
5bf24c7
add D matrix to perturbation logic
daniel-dudt Aug 12, 2024
9a71577
Merge branch 'master' into dd/current_scale
ddudt Aug 12, 2024
869f408
improved scaling
daniel-dudt Aug 13, 2024
240752c
added test
daniel-dudt Aug 13, 2024
127d453
looks like it's working!
daniel-dudt Aug 13, 2024
a30fc3f
fixed indexing
daniel-dudt Aug 13, 2024
0ac5456
also rescale fixed values of xp
daniel-dudt Aug 14, 2024
83095ff
Merge branch 'master' into dd/current_scale
ddudt Aug 14, 2024
bc6fe03
making improvements
daniel-dudt Aug 14, 2024
ef76206
Merge branch 'master' into dd/current_scale
ddudt Aug 14, 2024
40cf8f2
making fixes
daniel-dudt Aug 15, 2024
f846496
remove changes to HELIOTRON example
daniel-dudt Aug 15, 2024
20303c3
add custom/auto x_scale
daniel-dudt Aug 15, 2024
7624f34
fix remaining tests
daniel-dudt Aug 15, 2024
8ed20d7
update docstring
daniel-dudt Aug 15, 2024
0d3c24e
Merge branch 'master' into dd/current_scale
ddudt Aug 15, 2024
461c0ef
increase threshold to 20
daniel-dudt Aug 15, 2024
56103e6
all tests hopefully passing now?
daniel-dudt Aug 15, 2024
b9e9a9d
Merge branch 'master' into dd/current_scale
ddudt Aug 15, 2024
6187f56
making requested changes
daniel-dudt Aug 16, 2024
6b9233a
fix bug from last commit
daniel-dudt Aug 16, 2024
e65a07d
Merge branch 'master' into dd/current_scale
dpanici Aug 16, 2024
cc1ddee
always use default things
daniel-dudt Aug 16, 2024
92e5cf3
Merge branch 'master' into dd/current_scale
f0uriest Aug 17, 2024
1a60814
Merge branch 'master' into dd/current_scale
ddudt Aug 18, 2024
d386a35
Merge branch 'dd/current_scale' of https://github.com/PlasmaControl/D…
daniel-dudt Aug 19, 2024
ddb63d2
always use objective.things for auto scale
daniel-dudt Aug 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 59 additions & 15 deletions desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@
import numpy as np

from desc.backend import cond, jit, jnp, logsumexp, put
from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif


def factorize_linear_constraints(objective, constraint): # noqa: C901
from desc.utils import (
Index,
errorif,
flatten_list,
get_instance,
svd_inv_null,
unique_list,
warnif,
)


def factorize_linear_constraints( # noqa: C901
objective, constraint, things=None, x_scale="auto"
):
"""Compute and factorize A to get pseudoinverse and nullspace.

Given constraints of the form Ax=b, factorize A to find a particular solution xp
Expand All @@ -22,6 +32,13 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Objective function to optimize.
constraint : ObjectiveFunction
Objective function of linear constraints to enforce.
things : Optimizable or tuple/list of Optimizable
ddudt marked this conversation as resolved.
Show resolved Hide resolved
Things to optimize. Defaults to ``objective.things``.
Only used if ``x_scale='auto'``.
x_scale : array_like or ``'auto'``, optional
Characteristic scale of each variable. Setting ``x_scale`` is equivalent
to reformulating the problem in scaled variables ``xs = x / x_scale``.
If set to ``'auto'``, the scale is determined from the initial state vector.

Returns
-------
Expand All @@ -33,6 +50,8 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Combined RHS vector.
Z : ndarray
Null space operator for full combined A such that A @ Z == 0.
D : ndarray
Scale of the full state vector x, based on the particular solution xp.
unfixed_idx : ndarray
Indices of x that correspond to non-fixed values.
project, recover : function
Expand Down Expand Up @@ -130,32 +149,57 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
)
A = A[unfixed_rows][:, unfixed_idx]
b = b[unfixed_rows]

unfixed_idx = indices_idx
fixed_idx = np.delete(np.arange(xp.size), unfixed_idx)

# compute x_scale if not provided
if x_scale == "auto":
if things is None:
things = objective.things
else:
things = [things] if not isinstance(things, list) else things
things = [get_instance(things, type(t)) for t in objective.things]
ddudt marked this conversation as resolved.
Show resolved Hide resolved
x_scale = objective.x(*things)
errorif(
x_scale.shape != xp.shape,
ValueError,
"x_scale must be the same size as the full state vector.",
ddudt marked this conversation as resolved.
Show resolved Hide resolved
)
D = np.where(np.abs(x_scale) < 1e1, 1, np.abs(x_scale))

# null space & particular solution
A = A * D[None, unfixed_idx]
if A.size:
Ainv_full, Z = svd_inv_null(A)
A_inv, Z = svd_inv_null(A)
else:
Ainv_full = A.T
A_inv = A.T
Z = np.eye(A.shape[1])
Ainv_full = jnp.asarray(Ainv_full)
Z = jnp.asarray(Z)
b = jnp.asarray(b)
xp = put(xp, unfixed_idx, Ainv_full @ b)
xp = put(xp, unfixed_idx, A_inv @ b)
xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])

# cast to jnp arrays
xp = jnp.asarray(xp)
A = jnp.asarray(A)
b = jnp.asarray(b)
Z = jnp.asarray(Z)
D = jnp.asarray(D)

@jit
def project(x):
def project(x_full):
"""Project a full state vector into the reduced optimization vector."""
x_reduced = Z.T @ ((x - xp)[unfixed_idx])
x_reduced = Z.T @ ((1 / D) * x_full - xp)[unfixed_idx]
return jnp.atleast_1d(jnp.squeeze(x_reduced))

@jit
def recover(x_reduced):
"""Recover the full state vector from the reduced optimization vector."""
dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced)
return jnp.atleast_1d(jnp.squeeze(xp + dx))
x_full = D * (xp + dx)
return jnp.atleast_1d(jnp.squeeze(x_full))

# check that all constraints are actually satisfiable
params = objective.unpack_state(xp, False)
params = objective.unpack_state(D * xp, False)
for con in constraint.objectives:
xpi = [params[i] for i, t in enumerate(objective.things) if t in con.things]
y1 = con.compute_unscaled(*xpi)
Expand Down Expand Up @@ -197,7 +241,7 @@ def recover(x_reduced):
"or be due to floating point error.",
)

return xp, A, b, Z, unfixed_idx, project, recover
return xp, A, b, Z, D, unfixed_idx, project, recover


def softmax(arr, alpha):
Expand Down
36 changes: 21 additions & 15 deletions desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def build(self, use_jit=None, verbose=1):
self._A,
self._b,
self._Z,
self._D,
self._unfixed_idx,
self._project,
self._recover,
Expand All @@ -113,10 +114,8 @@ def build(self, use_jit=None, verbose=1):
self._dim_x = self._objective.dim_x
self._dim_x_reduced = self._Z.shape[1]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
self._unfixed_idx_mat = (
jnp.eye(self._objective.dim_x)[:, self._unfixed_idx] @ self._Z
)
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z

self._built = True
timer.stop("Linear constraint projection build")
Expand Down Expand Up @@ -261,7 +260,7 @@ def grad(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.grad(x, constants)
return df[self._unfixed_idx] @ self._Z
return df @ self._unfixed_idx_mat

def hess(self, x_reduced, constants=None):
"""Compute Hessian of self.compute_scalar.
Expand All @@ -281,13 +280,17 @@ def hess(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.hess(x, constants)
return self._Z.T @ df[self._unfixed_idx, :][:, self._unfixed_idx] @ self._Z
return (
(self._Z.T * (1 / self._D)[None, self._unfixed_idx])
@ df[self._unfixed_idx, :][:, self._unfixed_idx]
@ (self._Z * self._D[self._unfixed_idx, None])
)

def _jac(self, x_reduced, constants=None, op="scaled"):
x = self.recover(x_reduced)
if self._objective._deriv_mode == "blocked":
fun = getattr(self._objective, "jac_" + op)
return fun(x, constants)[:, self._unfixed_idx] @ self._Z
return fun(x, constants) @ self._unfixed_idx_mat
ddudt marked this conversation as resolved.
Show resolved Hide resolved

v = self._unfixed_idx_mat
df = getattr(self._objective, "jvp_" + op)(v.T, x, constants)
Expand Down Expand Up @@ -401,7 +404,7 @@ def jvp_unscaled(self, v, x_reduced, constants=None):
def _vjp(self, v, x_reduced, constants=None, op="vjp_scaled"):
x = self.recover(x_reduced)
df = getattr(self._objective, op)(v, x, constants)
return df[self._unfixed_idx] @ self._Z
return df @ self._unfixed_idx_mat

def vjp_scaled(self, v, x_reduced, constants=None):
"""Compute vector-Jacobian product of self.compute_scaled.
Expand Down Expand Up @@ -533,8 +536,10 @@ def _set_eq_state_vector(self):
self._args.remove(arg)
linear_constraint = ObjectiveFunction(self._linear_constraints)
linear_constraint.build()
_, A, _, self._Z, self._unfixed_idx, _, _ = factorize_linear_constraints(
self._constraint, linear_constraint
_, _, _, self._Z, self._D, self._unfixed_idx, _, _ = (
factorize_linear_constraints(
self._constraint, linear_constraint, things=[self._eq]
)
)

# dx/dc - goes from the full state to optimization variables for eq
Expand Down Expand Up @@ -618,14 +623,14 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
)
self._dimx_per_thing = [t.dim_x for t in self.things]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.eye(self._objective.dim_x)
self._unfixed_idx_mat = jnp.split(
self._unfixed_idx_mat, np.cumsum([t.dim_x for t in self.things]), axis=-1
)
self._unfixed_idx_mat[self._eq_idx] = (
self._unfixed_idx_mat[self._eq_idx][:, self._unfixed_idx] @ self._Z
)
self._unfixed_idx_mat[self._eq_idx] = self._unfixed_idx_mat[self._eq_idx][
:, self._unfixed_idx
] @ (self._Z * self._D[self._unfixed_idx, None])
self._unfixed_idx_mat = np.concatenate(
[np.atleast_2d(foo) for foo in self._unfixed_idx_mat], axis=-1
)
Expand Down Expand Up @@ -1018,7 +1023,8 @@ def jvp_unscaled(self, v, x, constants=None):
@functools.partial(jit, static_argnames=("self", "op"))
def _jvp_f(self, xf, dc, constants, op):
Fx = getattr(self._constraint, "jac_" + op)(xf, constants)
Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z
# TODO: replace with self._unfixed_idx_mat?
Fx_reduced = Fx @ jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z
Fc = Fx @ (self._dxdc @ dc)
Fxh = Fx_reduced
cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape)
Expand Down
24 changes: 12 additions & 12 deletions desc/perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
if verbose > 0:
print("Factorizing linear constraints")
timer.start("linear constraint factorize")
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint, things=eq
)
timer.stop("linear constraint factorize")
if verbose > 1:
Expand Down Expand Up @@ -291,7 +291,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
print("Computing df")
timer.start("df computation")
Jx = objective.jac_scaled_error(x)
Jx_reduced = Jx[:, unfixed_idx] @ Z @ scale
Jx_reduced = Jx @ jnp.diag(D)[:, unfixed_idx] @ Z @ scale
RHS1 = objective.jvp_scaled(tangents, x)
if include_f:
f = objective.compute_scaled_error(x)
Expand Down Expand Up @@ -388,8 +388,8 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
_, _, _, _, _, _, _, recover = factorize_linear_constraints(
objective, constraint, things=[eq_new]
)

# update other attributes
Expand Down Expand Up @@ -547,8 +547,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)

_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
_, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint, things=eq
)

# state vector
Expand All @@ -564,7 +564,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
dx2_reduced = 0

# dx/dx_reduced
dxdx_reduced = jnp.eye(eq.dim_x)[:, unfixed_idx] @ Z
dxdx_reduced = jnp.diag(D)[:, unfixed_idx] @ Z

# dx/dc
dxdc = []
Expand Down Expand Up @@ -612,8 +612,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
timer.disp("dg computation")

# projections onto optimization space
Fx_reduced = Fx[:, unfixed_idx] @ Z
Gx_reduced = Gx[:, unfixed_idx] @ Z
Fx_reduced = Fx @ jnp.diag(D)[:, unfixed_idx] @ Z
Gx_reduced = Gx @ jnp.diag(D)[:, unfixed_idx] @ Z
Fc = Fx @ dxdc
Gc = Gx @ dxdc

Expand Down Expand Up @@ -752,8 +752,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
_, _, _, _, _, _, _, recover = factorize_linear_constraints(
objective_f, constraint, things=[eq_new]
)

# update other attributes
Expand Down
2 changes: 1 addition & 1 deletion desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load(
constraints = maybe_add_self_consistency(eq, constraints)
objective = ObjectiveFunction(constraints)
objective.build(verbose=0)
_, _, _, _, _, project, recover = factorize_linear_constraints(
_, _, _, _, _, _, project, recover = factorize_linear_constraints(
objective, objective
)
args = objective.unpack_state(recover(project(objective.x(eq))), False)[0]
Expand Down
32 changes: 16 additions & 16 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ def test_correct_indexing_passed_modes():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint, things=eq
)

x1 = objective.x(eq)
Expand All @@ -461,8 +461,8 @@ def test_correct_indexing_passed_modes():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


Expand Down Expand Up @@ -514,8 +514,8 @@ def test_correct_indexing_passed_modes_and_passed_target():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint, things=eq
)

x1 = objective.x(eq)
Expand All @@ -524,8 +524,8 @@ def test_correct_indexing_passed_modes_and_passed_target():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


Expand Down Expand Up @@ -574,8 +574,8 @@ def test_correct_indexing_passed_modes_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint, things=eq
)

x1 = objective.x(eq)
Expand All @@ -584,8 +584,8 @@ def test_correct_indexing_passed_modes_axis():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


Expand Down Expand Up @@ -703,8 +703,8 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint, things=eq
)

x1 = objective.x(eq)
Expand All @@ -713,8 +713,8 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


Expand Down
Loading
Loading