Merge branch 'master' into ku/root_3d
dpanici authored Aug 20, 2024
2 parents c49f31c + dd3f472 commit 236af07
Showing 7 changed files with 112 additions and 54 deletions.
51 changes: 39 additions & 12 deletions desc/objectives/utils.py
@@ -9,7 +9,7 @@
from desc.utils import Index, errorif, flatten_list, svd_inv_null, unique_list, warnif


def factorize_linear_constraints(objective, constraint): # noqa: C901
def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa: C901
"""Compute and factorize A to get pseudoinverse and nullspace.
Given constraints of the form Ax=b, factorize A to find a particular solution xp
@@ -22,6 +22,10 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Objective function to optimize.
constraint : ObjectiveFunction
Objective function of linear constraints to enforce.
x_scale : array_like or ``'auto'``, optional
Characteristic scale of each variable. Setting ``x_scale`` is equivalent
to reformulating the problem in scaled variables ``xs = x / x_scale``.
If set to ``'auto'``, the scale is determined from the initial state vector.
Returns
-------
@@ -33,6 +37,8 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
Combined RHS vector.
Z : ndarray
Null space operator for full combined A such that A @ Z == 0.
D : ndarray
Scale of the full state vector x, as set by the parameter ``x_scale``.
unfixed_idx : ndarray
Indices of x that correspond to non-fixed values.
project, recover : function
@@ -130,32 +136,53 @@ def factorize_linear_constraints(objective, constraint): # noqa: C901
)
A = A[unfixed_rows][:, unfixed_idx]
b = b[unfixed_rows]

unfixed_idx = indices_idx
fixed_idx = np.delete(np.arange(xp.size), unfixed_idx)

# compute x_scale if not provided
if x_scale == "auto":
x_scale = objective.x(*objective.things)
errorif(
x_scale.shape != xp.shape,
ValueError,
"x_scale must be the same size as the full state vector. "
+ f"Got size {x_scale.size} for state vector of size {xp.size}.",
)
D = np.where(np.abs(x_scale) < 1e2, 1, np.abs(x_scale))

# null space & particular solution
A = A * D[None, unfixed_idx]
if A.size:
Ainv_full, Z = svd_inv_null(A)
A_inv, Z = svd_inv_null(A)
else:
Ainv_full = A.T
A_inv = A.T
Z = np.eye(A.shape[1])
Ainv_full = jnp.asarray(Ainv_full)
Z = jnp.asarray(Z)
b = jnp.asarray(b)
xp = put(xp, unfixed_idx, Ainv_full @ b)
xp = put(xp, unfixed_idx, A_inv @ b)
xp = put(xp, fixed_idx, ((1 / D) * xp)[fixed_idx])

# cast to jnp arrays
xp = jnp.asarray(xp)
A = jnp.asarray(A)
b = jnp.asarray(b)
Z = jnp.asarray(Z)
D = jnp.asarray(D)

@jit
def project(x):
def project(x_full):
"""Project a full state vector into the reduced optimization vector."""
x_reduced = Z.T @ ((x - xp)[unfixed_idx])
x_reduced = Z.T @ ((1 / D) * x_full - xp)[unfixed_idx]
return jnp.atleast_1d(jnp.squeeze(x_reduced))

@jit
def recover(x_reduced):
"""Recover the full state vector from the reduced optimization vector."""
dx = put(jnp.zeros(objective.dim_x), unfixed_idx, Z @ x_reduced)
return jnp.atleast_1d(jnp.squeeze(xp + dx))
x_full = D * (xp + dx)
return jnp.atleast_1d(jnp.squeeze(x_full))

# check that all constraints are actually satisfiable
params = objective.unpack_state(xp, False)
params = objective.unpack_state(D * xp, False)
for con in constraint.objectives:
xpi = [params[i] for i, t in enumerate(objective.things) if t in con.things]
y1 = con.compute_unscaled(*xpi)
@@ -197,7 +224,7 @@ def recover(x_reduced):
"or be due to floating point error.",
)

return xp, A, b, Z, unfixed_idx, project, recover
return xp, A, b, Z, D, unfixed_idx, project, recover


def softmax(arr, alpha):
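
A rough NumPy sketch of the scaled factorization these changes implement, to show how the new D fits in. It is illustrative only, not DESC code: the toy A, b, and D are made up, and DESC's version additionally handles fixed indices and uses svd_inv_null.

import numpy as np

# Constraints A x = b are rewritten in scaled variables xs = x / D, so that
# x = D * (xp + Z @ y): xp is a particular solution, Z spans the null space.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 5))             # 2 linear constraints on 5 variables
b = rng.standard_normal(2)
D = np.array([1.0, 10.0, 100.0, 1.0, 5.0])  # characteristic scales (the new x_scale)

A_scaled = A * D[None, :]                   # constraint matrix acting on xs = x / D
U, s, Vt = np.linalg.svd(A_scaled, full_matrices=True)
A_inv = Vt[: s.size].T @ np.diag(1 / s) @ U.T   # pseudoinverse (like svd_inv_null)
Z = Vt[s.size :].T                          # null space: A_scaled @ Z == 0
xp = A_inv @ b                              # particular solution in scaled variables

def recover(y):
    """Full (unscaled) state vector from reduced coordinates y."""
    return D * (xp + Z @ y)

def project(x):
    """Reduced coordinates from a full (unscaled) state vector x."""
    return Z.T @ (x / D - xp)

y = rng.standard_normal(Z.shape[1])
x = recover(y)
assert np.allclose(A @ x, b)                # every recovered x satisfies A x = b
assert np.allclose(project(recover(y)), y)  # round trip in reduced coordinates
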
36 changes: 21 additions & 15 deletions desc/optimize/_constraint_wrappers.py
@@ -103,6 +103,7 @@ def build(self, use_jit=None, verbose=1):
self._A,
self._b,
self._Z,
self._D,
self._unfixed_idx,
self._project,
self._recover,
@@ -113,10 +114,8 @@ def build(self, use_jit=None, verbose=1):
self._dim_x = self._objective.dim_x
self._dim_x_reduced = self._Z.shape[1]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
self._unfixed_idx_mat = (
jnp.eye(self._objective.dim_x)[:, self._unfixed_idx] @ self._Z
)
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z

self._built = True
timer.stop("Linear constraint projection build")
@@ -261,7 +260,7 @@ def grad(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.grad(x, constants)
return df[self._unfixed_idx] @ self._Z
return df[self._unfixed_idx] @ (self._Z * self._D[self._unfixed_idx, None])

def hess(self, x_reduced, constants=None):
"""Compute Hessian of self.compute_scalar.
@@ -281,13 +280,19 @@ def hess(self, x_reduced, constants=None):
"""
x = self.recover(x_reduced)
df = self._objective.hess(x, constants)
return self._Z.T @ df[self._unfixed_idx, :][:, self._unfixed_idx] @ self._Z
return (
(self._Z.T * (1 / self._D)[None, self._unfixed_idx])
@ df[self._unfixed_idx, :][:, self._unfixed_idx]
@ (self._Z * self._D[self._unfixed_idx, None])
)

def _jac(self, x_reduced, constants=None, op="scaled"):
x = self.recover(x_reduced)
if self._objective._deriv_mode == "blocked":
fun = getattr(self._objective, "jac_" + op)
return fun(x, constants)[:, self._unfixed_idx] @ self._Z
return fun(x, constants)[:, self._unfixed_idx] @ (
self._Z * self._D[self._unfixed_idx, None]
)

v = self._unfixed_idx_mat
df = getattr(self._objective, "jvp_" + op)(v.T, x, constants)
@@ -401,7 +406,7 @@ def jvp_unscaled(self, v, x_reduced, constants=None):
def _vjp(self, v, x_reduced, constants=None, op="vjp_scaled"):
x = self.recover(x_reduced)
df = getattr(self._objective, op)(v, x, constants)
return df[self._unfixed_idx] @ self._Z
return df[self._unfixed_idx] @ (self._Z * self._D[self._unfixed_idx, None])

def vjp_scaled(self, v, x_reduced, constants=None):
"""Compute vector-Jacobian product of self.compute_scaled.
@@ -533,8 +538,8 @@ def _set_eq_state_vector(self):
self._args.remove(arg)
linear_constraint = ObjectiveFunction(self._linear_constraints)
linear_constraint.build()
_, A, _, self._Z, self._unfixed_idx, _, _ = factorize_linear_constraints(
self._constraint, linear_constraint
_, _, _, self._Z, self._D, self._unfixed_idx, _, _ = (
factorize_linear_constraints(self._constraint, linear_constraint)
)

# dx/dc - goes from the full state to optimization variables for eq
@@ -618,14 +623,14 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
)
self._dimx_per_thing = [t.dim_x for t in self.things]

# equivalent matrix for A[unfixed_idx]@Z == A@unfixed_idx_mat
# equivalent matrix for A[unfixed_idx] @ D @ Z == A @ unfixed_idx_mat
self._unfixed_idx_mat = jnp.eye(self._objective.dim_x)
self._unfixed_idx_mat = jnp.split(
self._unfixed_idx_mat, np.cumsum([t.dim_x for t in self.things]), axis=-1
)
self._unfixed_idx_mat[self._eq_idx] = (
self._unfixed_idx_mat[self._eq_idx][:, self._unfixed_idx] @ self._Z
)
self._unfixed_idx_mat[self._eq_idx] = self._unfixed_idx_mat[self._eq_idx][
:, self._unfixed_idx
] @ (self._Z * self._D[self._unfixed_idx, None])
self._unfixed_idx_mat = np.concatenate(
[np.atleast_2d(foo) for foo in self._unfixed_idx_mat], axis=-1
)
@@ -1018,7 +1023,8 @@ def jvp_unscaled(self, v, x, constants=None):
@functools.partial(jit, static_argnames=("self", "op"))
def _jvp_f(self, xf, dc, constants, op):
Fx = getattr(self._constraint, "jac_" + op)(xf, constants)
Fx_reduced = Fx[:, self._unfixed_idx] @ self._Z
# TODO: replace with self._unfixed_idx_mat?
Fx_reduced = Fx @ jnp.diag(self._D)[:, self._unfixed_idx] @ self._Z
Fc = Fx @ (self._dxdc @ dc)
Fxh = Fx_reduced
cutoff = jnp.finfo(Fxh.dtype).eps * max(Fxh.shape)
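
Since the recovered state is x = D * (xp + E @ y) with E = I[:, unfixed_idx] @ Z, the Jacobian of the recovery map is diag(D)[:, unfixed_idx] @ Z, which is why the projected derivatives above pick up the factor Z * D[unfixed_idx, None]. A small self-contained check of that chain rule for the reduced gradient, with toy sizes and a quadratic objective (illustrative only, not DESC code):

import numpy as np

rng = np.random.default_rng(1)
n, k = 6, 3
unfixed_idx = np.array([0, 2, 3, 5])
Z = np.linalg.qr(rng.standard_normal((unfixed_idx.size, k)))[0]  # reduced basis
D = rng.uniform(0.5, 50.0, n)                                    # per-variable scales
xp = rng.standard_normal(n)
H = rng.standard_normal((n, n))
H = H @ H.T                                    # toy objective f(x) = 0.5 * x^T H x

def recover(y):
    dx = np.zeros(n)
    dx[unfixed_idx] = Z @ y
    return D * (xp + dx)

def f(y):
    x = recover(y)
    return 0.5 * x @ H @ x

y = rng.standard_normal(k)
df = H @ recover(y)                            # full-space gradient at recover(y)
grad_reduced = df[unfixed_idx] @ (Z * D[unfixed_idx, None])

eps = 1e-6                                     # central finite differences in y
fd = np.array([(f(y + eps * e) - f(y - eps * e)) / (2 * eps) for e in np.eye(k)])
np.testing.assert_allclose(grad_reduced, fd, rtol=1e-6, atol=1e-4)
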
20 changes: 8 additions & 12 deletions desc/perturbations.py
@@ -185,7 +185,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
if verbose > 0:
print("Factorizing linear constraints")
timer.start("linear constraint factorize")
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)
timer.stop("linear constraint factorize")
@@ -291,7 +291,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
print("Computing df")
timer.start("df computation")
Jx = objective.jac_scaled_error(x)
Jx_reduced = Jx[:, unfixed_idx] @ Z @ scale
Jx_reduced = Jx @ jnp.diag(D)[:, unfixed_idx] @ Z @ scale
RHS1 = objective.jvp_scaled(tangents, x)
if include_f:
f = objective.compute_scaled_error(x)
@@ -388,9 +388,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
xp, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective, constraint)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced + dx3_reduced
@@ -547,7 +545,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)

_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
_, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
)

@@ -564,7 +562,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
dx2_reduced = 0

# dx/dx_reduced
dxdx_reduced = jnp.eye(eq.dim_x)[:, unfixed_idx] @ Z
dxdx_reduced = jnp.diag(D)[:, unfixed_idx] @ Z

# dx/dc
dxdc = []
@@ -612,8 +610,8 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
timer.disp("dg computation")

# projections onto optimization space
Fx_reduced = Fx[:, unfixed_idx] @ Z
Gx_reduced = Gx[:, unfixed_idx] @ Z
Fx_reduced = Fx @ jnp.diag(D)[:, unfixed_idx] @ Z
Gx_reduced = Gx @ jnp.diag(D)[:, unfixed_idx] @ Z
Fc = Fx @ dxdc
Gc = Gx @ dxdc

@@ -752,9 +750,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
_, _, _, Z, unfixed_idx, project, recover = factorize_linear_constraints(
objective_f, constraint
)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective_f, constraint)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced
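
The perturbation routines build the projection as Jx @ jnp.diag(D)[:, unfixed_idx] @ Z, while the constraint wrappers use the equivalent Jx[:, unfixed_idx] @ (Z * D[unfixed_idx, None]) without materializing diag(D). A quick NumPy check that the two forms agree, with toy shapes (illustrative only):

import numpy as np

rng = np.random.default_rng(2)
dim_x, k, m = 7, 3, 4
unfixed_idx = np.array([0, 1, 3, 4, 6])
D = rng.uniform(0.5, 50.0, dim_x)
Z = rng.standard_normal((unfixed_idx.size, k))
Jx = rng.standard_normal((m, dim_x))

form1 = Jx @ np.diag(D)[:, unfixed_idx] @ Z               # as in perturbations.py
form2 = Jx[:, unfixed_idx] @ (Z * D[unfixed_idx, None])   # as in _constraint_wrappers.py
np.testing.assert_allclose(form1, form2, rtol=1e-12, atol=1e-12)
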
2 changes: 1 addition & 1 deletion desc/vmec.py
@@ -192,7 +192,7 @@ def load(
constraints = maybe_add_self_consistency(eq, constraints)
objective = ObjectiveFunction(constraints)
objective.build(verbose=0)
_, _, _, _, _, project, recover = factorize_linear_constraints(
_, _, _, _, _, _, project, recover = factorize_linear_constraints(
objective, objective
)
args = objective.unpack_state(recover(project(objective.x(eq))), False)[0]
4 changes: 2 additions & 2 deletions tests/test_bootstrap.py
@@ -1589,7 +1589,7 @@ def test_bootstrap_optimization_comparison_qa():
objective=objective,
constraints=constraints,
optimizer="proximal-lsq-exact",
maxiter=4,
maxiter=5,
gtol=1e-16,
verbose=3,
)
@@ -1622,5 +1622,5 @@ def test_bootstrap_optimization_comparison_qa():
grid.compress(data2["<J*B>"]), grid.compress(data2["<J*B> Redl"]), rtol=1.8e-2
)
np.testing.assert_allclose(
grid.compress(data1["<J*B>"]), grid.compress(data2["<J*B>"]), rtol=1.8e-2
grid.compress(data1["<J*B>"]), grid.compress(data2["<J*B>"]), rtol=1.9e-2
)
24 changes: 12 additions & 12 deletions tests/test_linear_objectives.py
@@ -451,7 +451,7 @@ def test_correct_indexing_passed_modes():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)

@@ -461,8 +461,8 @@ def test_correct_indexing_passed_modes():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


@@ -514,7 +514,7 @@ def test_correct_indexing_passed_modes_and_passed_target():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)

@@ -524,8 +524,8 @@ def test_correct_indexing_passed_modes_and_passed_target():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


@@ -574,7 +574,7 @@ def test_correct_indexing_passed_modes_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)

@@ -584,8 +584,8 @@ def test_correct_indexing_passed_modes_axis():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


@@ -703,7 +703,7 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
objective, constraint
)

@@ -713,8 +713,8 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
atol = 2e-15
np.testing.assert_allclose(x1, x2, atol=atol)
np.testing.assert_allclose(A @ xp[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x1[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ x2[unfixed_idx], b, atol=atol)
np.testing.assert_allclose(A @ (x1[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ (x2[unfixed_idx] / D[unfixed_idx]), b, atol=atol)
np.testing.assert_allclose(A @ Z, 0, atol=atol)


