Add ADMM + parameter reordering FISTA #553

Merged
Changes from 8 commits
4 changes: 2 additions & 2 deletions cuqi/experimental/mcmc/_rto.py
@@ -235,8 +235,8 @@ def prior(self):

def step(self):
y = self.b_tild + np.random.randn(len(self.b_tild))
- sim = FISTA(self.M, y, self.current_point, self.proximal,
-             maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
+ sim = FISTA(self.M, y, self.proximal,
+             self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
self.current_point, _ = sim.solve()
acc = 1
return acc
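Note: these call sites change only because FISTA's positional argument order changes; the proximal operator now comes before the initial guess. A minimal sketch of the new order (random data and stepsize chosen for illustration only):

import numpy as np
from cuqi.solver import FISTA, ProximalL1

rng = np.random.default_rng()
A = rng.standard_normal((10, 5))
b = rng.standard_normal(10)
x0 = np.zeros(5)

# New order: FISTA(A, b, proximal, x0, ...); previously FISTA(A, b, x0, proximal, ...)
sol, _ = FISTA(A, b, ProximalL1, x0, stepsize=0.1, maxit=100).solve()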
4 changes: 2 additions & 2 deletions cuqi/sampler/_rto.py
@@ -267,8 +267,8 @@ def _sample(self, N, Nb):
samples[:, 0] = self.x0
for s in range(Ns-1):
y = self.b_tild + np.random.randn(len(self.b_tild))
- sim = FISTA(self.M, y, samples[:, s], self.proximal,
-             maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
+ sim = FISTA(self.M, y, self.proximal,
+             samples[:, s], maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
samples[:, s+1], _ = sim.solve()

self._print_progress(s+2,Ns) #s+2 is the sample number, s+1 is index assuming x0 is the first sample
1 change: 1 addition & 0 deletions cuqi/solver/__init__.py
@@ -7,6 +7,7 @@
LM,
PDHG,
FISTA,
+ ADMM,
ProjectNonnegative,
ProjectBox,
ProximalL1
164 changes: 160 additions & 4 deletions cuqi/solver/_solver.py
@@ -584,8 +584,8 @@ class FISTA(object):
----------
A : ndarray or callable f(x,*args).
b : ndarray.
- x0 : ndarray. Initial guess.
proximal : callable f(x, gamma) for proximal mapping.
+ x0 : ndarray. Initial guess.
maxit : The maximum number of iterations.
stepsize : The stepsize of the gradient step.
abstol : The numerical tolerance for convergence checks.
@@ -606,11 +606,11 @@ class FISTA(object):
b = rng.standard_normal(m)
stepsize = 0.99/(sp.linalg.interpolative.estimate_spectral_norm(A)**2)
x0 = np.zeros(n)
- fista = FISTA(A, b, x0, proximal = ProximalL1, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True)
+ fista = FISTA(A, b, ProximalL1, x0, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True)
sol, _ = fista.solve()

"""
- def __init__(self, A, b, x0, proximal, maxit=100, stepsize=1e0, abstol=1e-14, adaptive = True):
+ def __init__(self, A, b, proximal, x0, maxit=100, stepsize=1e0, abstol=1e-14, adaptive = True):

self.A = A
self.b = b
@@ -650,8 +650,148 @@ def solve(self):
x_new = x_new + ((k-1)/(k+2))*(x_new - x_old)

x = x_new.copy()

class ADMM(object):
"""Alternating Direction Method of Multipliers for solving regularized linear least squares problems of the form:
Minimize ||Ax-b||^2 + sum_i f_i(L_i x).

Reference:
[1] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine Learning, 2011.


Parameters
----------
A : ndarray or callable f(x,*args).
b : ndarray.
penalties : List of tuples (callable proximal operator of f_i, linear operator L_i).
x0 : ndarray. Initial guess.
tradeoff : Trade-off between the linear least squares term and the regularization terms in the solver iterates.
maxit : The maximum number of iterations.
inner_max_it : The maximum number of iterations for the inner CGLS least squares solver.
adaptive : Whether to adaptively update the tradeoff parameter each iteration. Based on [1], Subsection 3.4.1.

Example
-----------
.. code-block:: python

from cuqi.solver import ADMM, ProximalL1, ProjectNonnegative
import numpy as np

rng = np.random.default_rng()

m, n = 10, 5
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)
x0 = np.zeros(n)
admm = ADMM(A, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), np.eye(n))], x0, tradeoff = 10)
sol, _ = admm.solve()

"""

def __init__(self, A, b, penalties, x0, tradeoff = 10, maxit = 100, inner_max_it = 10, adaptive = True):

self.A = A
self.b = b
self.x_cur = x0

dual_len = [penalty[1].shape[0] for penalty in penalties]
self.y_cur = [np.zeros(l) for l in dual_len]
self.u_cur = [np.zeros(l) for l in dual_len]
self.n = penalties[0][1].shape[1]

self.rho = tradeoff
self.maxit = maxit
self.inner_max_it = inner_max_it
self.adaptive = adaptive

self.penalties = penalties

self.p = len(self.penalties)
self._big_matrix = None

def solve(self):
y_new = self.p*[0]
u_new = self.p*[0]

# Iterating
for i in range(self.maxit):
big_matrix, big_vector = self._iteration_pre_processing()

# Main update (Least Squares)
solver = CGLS(big_matrix, big_vector, self.x_cur, self.inner_max_it)
x_new, _ = solver.solve()

# Regularization update
for j, penalty in enumerate(self.penalties):
y_new[j] = penalty[0](penalty[1]@x_new + self.u_cur[j], 1.0/self.rho)

res_primal = 0.0
# Dual update
for j, penalty in enumerate(self.penalties):
r_partial = penalty[1]@x_new - y_new[j]
res_primal += LA.norm(r_partial)**2

u_new[j] = self.u_cur[j] + r_partial

res_dual = 0.0
for j, penalty in enumerate(self.penalties):
res_dual += LA.norm(penalty[1].T@(y_new[j] - self.y_cur[j]))**2

# Adaptive approach based on [1], Subsection 3.4.1
if self.adaptive:
if res_dual > 1e2*res_primal:
self.rho *= 0.5 # More regularization
elif res_primal > 1e2*res_dual:
self.rho *= 2.0 # More data fidelity

self.x_cur, self.y_cur, self.u_cur = x_new, y_new.copy(), u_new

return self.x_cur, i
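# Correspondence with the scaled-form ADMM of [1], Subsection 3.1.1 (a sketch;
# rho plays the role of the tradeoff parameter):
#   x^{k+1}   = argmin_x (1/rho)\|Ax - b\|_2^2 + \sum_i \|L_i x - (y_i^k - u_i^k)\|_2^2   (CGLS step)
#   y_i^{k+1} = prox_{f_i, 1/rho}(L_i x^{k+1} + u_i^k)                                    (regularization update)
#   u_i^{k+1} = u_i^k + L_i x^{k+1} - y_i^{k+1}                                           (dual update)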

def _iteration_pre_processing(self):
""" Preprocessing
Every iteration of ADMM requires solving a linear least squares system of the form
minimize (1/rho) \|Ax - b\|_2^2 + \sum_{i=1}^{p} \|L_i x - (y_i - u_i)\|_2^2
To solve this, all linear least squares terms are combined into a single big term
with matrix big_matrix and data big_vector.

The matrix only needs to be updated when rho changes, i.e., when the adaptive option is used.
The data vector needs to be updated every iteration.
"""

big_vector = np.hstack([np.sqrt(1/self.rho)*self.b] + [self.y_cur[i] - self.u_cur[i] for i in range(self.p)])

# Check whether matrix needs to be updated
if self._big_matrix is not None and not self.adaptive:
return self._big_matrix, big_vector

# Update big_matrix
if callable(self.A):
def matrix_eval(x, flag):
if flag == 1:
out1 = np.sqrt(1/self.rho)*self.A(x, 1)
out2 = [penalty[1]@x for penalty in self.penalties]
out = np.hstack([out1] + out2)
elif flag == 2:
idx_start = len(x)
idx_end = len(x)
out1 = np.zeros(self.n)
for _, t in reversed(self.penalties):
idx_start -= t.shape[0]
out1 += t.T@x[idx_start:idx_end]
idx_end = idx_start
out2 = np.sqrt(1/self.rho)*self.A(x[:idx_end], 2)
out = out1 + out2
return out
self._big_matrix = matrix_eval
else:
self._big_matrix = np.vstack([np.sqrt(1/self.rho)*self.A] + [penalty[1] for penalty in self.penalties])

return self._big_matrix, big_vector
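# Illustration (hypothetical two-penalty case): the stacked least squares system
# handed to CGLS is
#   big_matrix = [ sqrt(1/rho)*A ]      big_vector = [ sqrt(1/rho)*b ]
#                [      L_1      ]                   [   y_1 - u_1   ]
#                [      L_2      ]                   [   y_2 - u_2   ]
# and the callable branch applies the same stacked operator (flag=1) and its
# adjoint (flag=2) without ever forming the matrix.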




def ProjectNonnegative(x):
"""(Euclidean) projection onto the nonnegative orthant.

@@ -678,6 +818,22 @@ def ProjectBox(x, lower = None, upper = None):

return np.minimum(np.maximum(x, lower), upper)

def ProjectHalfspace(x, a, b):
"""(Euclidean) projection onto the halfspace defined {z|<a,z> <= b}.

Parameters
----------
x : array_like.
a : array_like.
b : scalar.
"""

ax_b = np.inner(a,x) - b
if ax_b <= 0:
return x
else:
return x - (ax_b/np.inner(a,a))*a
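# Example (sketch): with a = np.array([1.0, 0.0]) and b = 0.0, the point [2, 3]
# violates <a, x> <= b and projects to [0, 3], while [-1, 3] is already feasible
# and is returned unchanged.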

def ProximalL1(x, gamma):
"""(Euclidean) proximal operator of the \|x\|_1 norm.
Also known as the shrinkage or soft thresholding operator.
@@ -687,4 +843,4 @@ def ProximalL1(x, gamma):
x : array_like.
gamma : scale parameter.
"""
return np.multiply(np.sign(x), np.maximum(np.abs(x)-gamma, 0))
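As a quick sanity check of the shrinkage formula, a minimal example using only ProximalL1 as defined above (values chosen for illustration):

import numpy as np
from cuqi.solver import ProximalL1

# Entries with magnitude below gamma = 1 are zeroed; the rest shrink toward zero by 1.
x = np.array([3.0, -0.5, 1.2, -2.0])
ProximalL1(x, 1.0)   # array([ 2. ,  0. ,  0.2, -1. ])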
54 changes: 51 additions & 3 deletions tests/test_solver.py
@@ -1,7 +1,7 @@
import numpy as np
import scipy as sp

- from cuqi.solver import CGLS, LM, FISTA, ProximalL1
+ from cuqi.solver import CGLS, LM, FISTA, ADMM, ProximalL1, ProjectNonnegative
from scipy.optimize import lsq_linear


@@ -54,8 +54,56 @@ def test_FISTA():

stepsize = 0.99/(sp.linalg.interpolative.estimate_spectral_norm(A)**2)
x0 = np.zeros(n)
- sol, _ = FISTA(A, b, x0, proximal = ProximalL1, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True).solve()
+ sol, _ = FISTA(A, b, ProximalL1, x0, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True).solve()

ref_sol = np.array([-1.83273787e-03, -1.72094582e-13, 0.0, -3.35835639e-01, -1.27795593e-01])
# Compare
assert np.allclose(sol, ref_sol, atol=1e-4)

def test_ADMM_matrix_form():
# Parameters
rng = np.random.default_rng(seed = 42)
m, n = 10, 5
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)

k = 4
L = rng.standard_normal((k, n))

x0 = np.zeros(n)
sol, _ = ADMM(A, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), L)],
x0, 10, maxit = 100, adaptive = True).solve()

ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02, -7.44973888e-02, -3.35005208e-11])
# Compare
assert np.allclose(sol, ref_sol, atol=1e-4)


def test_ADMM_function_form():
# Parameters
rng = np.random.default_rng(seed = 42)
m, n = 10, 5
A = rng.standard_normal((m, n))
def A_fun(x, flag):
if flag == 1:
return A@x
if flag == 2:
return A.T@x

b = rng.standard_normal(m)

k = 4
L = rng.standard_normal((k, n))

x0 = np.zeros(n)
sol, _ = ADMM(A_fun, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), L)],
x0, 10, maxit = 100, adaptive = True).solve()

ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02, -7.44973888e-02, -3.35005208e-11])
# Compare
assert np.allclose(sol, ref_sol, atol=1e-4)
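# Optional extra check (a sketch one could add): the flag convention requires
# A_fun to implement an adjoint pair, i.e. <u, A_fun(v, 1)> == <A_fun(u, 2), v>
# for all u, v.
u = rng.standard_normal(m)
v = rng.standard_normal(n)
assert np.isclose(np.inner(u, A_fun(v, 1)), np.inner(A_fun(u, 2), v))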
