Skip to content

Commit

Permalink
Merge branch 'master' into yge/print
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma authored Aug 21, 2024
2 parents bccf5fd + 425fb02 commit 7bd3df7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
14 changes: 12 additions & 2 deletions desc/optimize/aug_lagrangian_ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from scipy.optimize import NonlinearConstraint, OptimizeResult

from desc.backend import jnp
from desc.backend import jnp, qr
from desc.utils import errorif, setdefault

from .bound_utils import (
Expand All @@ -25,6 +25,7 @@
inequality_to_bounds,
print_header_nonlinear,
print_iteration_nonlinear,
solve_triangular_regularized,
)


Expand Down Expand Up @@ -368,6 +369,15 @@ def lagjac(z, y, mu, *args):
U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False)
elif tr_method == "cho":
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
if tall:
Q, R = qr(J_a, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ L_a)
else:
Q, R = qr(J_a.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -L_a, lower=True)

actual_reduction = -1
Lactual_reduction = -1
Expand All @@ -390,7 +400,7 @@ def lagjac(z, y, mu, *args):
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
L_a, J_a, trust_radius, alpha
p_newton, L_a, J_a, trust_radius, alpha
)

step = d * step_h # Trust-region solution in the original space.
Expand Down
14 changes: 12 additions & 2 deletions desc/optimize/least_squares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from scipy.optimize import OptimizeResult

from desc.backend import jnp
from desc.backend import jnp, qr
from desc.utils import errorif, setdefault

from .bound_utils import (
Expand All @@ -24,6 +24,7 @@
compute_jac_scale,
print_header_nonlinear,
print_iteration_nonlinear,
solve_triangular_regularized,
)


Expand Down Expand Up @@ -268,6 +269,15 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
U, s, Vt = jnp.linalg.svd(J_a, full_matrices=False)
elif tr_method == "cho":
B_h = jnp.dot(J_a.T, J_a)
elif tr_method == "qr":
# try full newton step
tall = J_a.shape[0] >= J_a.shape[1]
if tall:
Q, R = qr(J_a, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f_a)
else:
Q, R = qr(J_a.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f_a, lower=True)

actual_reduction = -1

Expand All @@ -289,7 +299,7 @@ def lsqtr( # noqa: C901 - FIXME: simplify this
)
elif tr_method == "qr":
step_h, hits_boundary, alpha = trust_region_step_exact_qr(
f_a, J_a, trust_radius, alpha
p_newton, f_a, J_a, trust_radius, alpha
)
step = d * step_h # Trust-region solution in the original space.

Expand Down
12 changes: 3 additions & 9 deletions desc/optimize/tr_subproblems.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def loop_body(state):

@jit
def trust_region_step_exact_qr(
f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
):
"""Solve a trust-region problem using a semi-exact method.
Expand Down Expand Up @@ -414,14 +414,6 @@ def trust_region_step_exact_qr(
Sometimes called Levenberg-Marquardt parameter.
"""
# try full newton step
tall = J.shape[0] >= J.shape[1]
if tall:
Q, R = qr(J, mode="economic")
p_newton = solve_triangular_regularized(R, -Q.T @ f)
else:
Q, R = qr(J.T, mode="economic")
p_newton = Q @ solve_triangular_regularized(R.T, -f, lower=True)

def truefun(*_):
return p_newton, False, 0.0
Expand Down Expand Up @@ -453,6 +445,7 @@ def loop_body(state):
Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
# Ji is always tall since its padded by alpha*I
Q, R = qr(Ji, mode="economic")

p = solve_triangular_regularized(R, -Q.T @ fp)
p_norm = jnp.linalg.norm(p)
phi = p_norm - trust_radius
Expand All @@ -474,6 +467,7 @@ def loop_body(state):
alpha, *_ = while_loop(
loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
)

Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
Q, R = qr(Ji, mode="economic")
p = solve_triangular(R, -Q.T @ fp)
Expand Down

0 comments on commit 7bd3df7

Please sign in to comment.