Skip to content

Commit

Permalink
Misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Apr 2, 2020
1 parent 8e0bfd8 commit 002fc6f
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 57 deletions.
13 changes: 7 additions & 6 deletions examples/solvers/douglas_rachford_pd_heron.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,15 @@ def print_objective(x):
tau=tau, sigma=sigma, niter=20, lam=lam,
callback=print_objective, l=l)

# plot the result
# Plot the result
fig, ax = plt.subplots()
for minp, maxp in rectangles:
xp = [minp[0], maxp[0], maxp[0], minp[0], minp[0]]
yp = [minp[1], minp[1], maxp[1], maxp[1], minp[1]]
plt.plot(xp, yp)
ax.plot(xp, yp)

plt.scatter(x[0], x[1])
ax.scatter(x[0], x[1])

plt.xlim(-1, 4)
plt.ylim(-1, 4)
plt.show()
ax.set_xlim(-1, 4)
ax.set_ylim(-1, 4)
fig.show()
10 changes: 5 additions & 5 deletions examples/solvers/douglas_rachford_pd_mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
# Create noisy MRI data
phantom = odl.phantom.shepp_logan(space, modified=True)
noisy_data = mri_op(phantom) + odl.phantom.white_noise(mri_op.range) * 0.1
phantom.show('Phantom')
noisy_data.show('Noisy MRI Data')
space.show(phantom, 'Phantom')
ft.range.show(noisy_data, 'Noisy MRI Data')

# Gradient for TV regularization
gradient = odl.Gradient(space)
Expand All @@ -44,11 +44,11 @@

# Solve
x = mri_op.domain.zero()
callback = (odl.solvers.CallbackShow(step=5, clim=[0, 1]) &
callback = (odl.solvers.CallbackShow(space, step=5, clim=[0, 1]) &
odl.solvers.CallbackPrintIteration())
odl.solvers.douglas_rachford_pd(x, f, g, lin_ops,
tau=2.0, sigma=[1.0, 0.1],
niter=500, callback=callback)

x.show('Douglas-Rachford Result')
ft.inverse(noisy_data).show('Fourier Inversion Result', force_show=True)
space.show(x, 'TV-regularized Result (Douglas-Rachford)')
space.show(ft.inverse(noisy_data), 'Fourier Inversion Result', force_show=True)
24 changes: 0 additions & 24 deletions examples/space/simple_rn.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def element(self, *args, **kwargs):
# Do some tests to compare
n = 10 ** 7
iterations = 10
cuda_supported = 'cuda' in odl.space.entry_points.tensor_space_impl_names()

# Perform some benchmarks with rn
opt_space = odl.rn(n)
Expand All @@ -71,10 +70,6 @@ def element(self, *args, **kwargs):
ox, oy, oz = (opt_space.copy(a) for a in (x, y, z))
sx, sy, sz = (simple_space.copy(a) for a in (x, y, z))

if cuda_supported:
cu_space = odl.rn(n, impl='cuda')
cx, cy, cz = (cu_space.element(a.copy()) for a in (x, y, z))

print(" lincomb:")
with timer("SimpleRn"):
for _ in range(iterations):
Expand All @@ -86,13 +81,6 @@ def element(self, *args, **kwargs):
opt_space.lincomb(2.13, ox, 3.14, oy, out=oz)
print("result: {}".format(oz[1:5]))

if cuda_supported:
with timer("odl cuda"):
for _ in range(iterations):
cu_space.lincomb(2.13, cx, 3.14, cy, out=cz)
print("result: {}".format(cz[1:5]))


print("\n Norm:")
with timer("SimpleRn"):
for _ in range(iterations):
Expand All @@ -104,12 +92,6 @@ def element(self, *args, **kwargs):
result = opt_space.norm(oz)
print("result: {}".format(result))

if cuda_supported:
with timer("odl cuda"):
for _ in range(iterations):
result = cu_space.norm(cz)
print("result: {}".format(result))


print("\n Inner:")
with timer("SimpleRn"):
Expand All @@ -121,9 +103,3 @@ def element(self, *args, **kwargs):
for _ in range(iterations):
result = opt_space.inner(ox, oz)
print("result: {}".format(result))

if cuda_supported:
with timer("odl cuda"):
for _ in range(iterations):
result = cu_space.inner(cx, cz)
print("result: {}".format(result))
5 changes: 3 additions & 2 deletions odl/solvers/functional/default_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,9 @@ def proximal(self):
elif self.exponent == 1:
return proximal_convex_conj_linfty(space=self.domain)
else:
raise NotImplementedError('`proximal` only implemented for p=1, '
'p=2 or p=inf')
raise NotImplementedError(
'`proximal` only implemented for p=2 and p=inf'
)

def __repr__(self):
"""Return ``repr(self)``."""
Expand Down
28 changes: 14 additions & 14 deletions odl/solvers/nonsmooth/proximal_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ def convex_conj_prox_factory(sigma):
# prox_factory accepts stepsize objects of the type given by sigma.
space = prox_factory(sigma).domain

mult_inner = MultiplyOperator(space, 1.0 / sigma)
mult_inner = MultiplyOperator(space, 1 / sigma)
mult_outer = MultiplyOperator(space, sigma)
result = (IdentityOperator(space) -
mult_outer * prox_factory(1.0 / sigma) * mult_inner)
mult_outer * prox_factory(1 / sigma) * mult_inner)
return result

return convex_conj_prox_factory
Expand Down Expand Up @@ -292,7 +292,7 @@ def proximal_arg_scaling(prox_factory, scaling):
# unconditionally, but only if the scaling factor is a scalar:
if np.isscalar(scaling):
if scaling == 0:
return proximal_const_func(prox_factory(1.0).domain)
return proximal_const_func(prox_factory(1).domain)
elif scaling.imag != 0:
raise ValueError("Complex scaling not supported.")
else:
Expand Down Expand Up @@ -327,7 +327,7 @@ def arg_scaling_prox_factory(sigma):


def proximal_quadratic_perturbation(prox_factory, a, u=None):
r"""Calculate the proximal of function F(x) + a * \|x\|^2 + <u,x>.
r"""Calculate the proximal of function F(x) + a * ||x||^2 + <u,x>.
Parameters
----------
Expand Down Expand Up @@ -377,8 +377,9 @@ def proximal_quadratic_perturbation(prox_factory, a, u=None):
"""
a = float(a)
if a < 0:
raise ValueError('scaling parameter muts be non-negative, got {}'
''.format(a))
raise ValueError(
'scaling parameter must be non-negative, got {}'.format(a)
)

def quadratic_perturbation_prox_factory(sigma):
r"""Create proximal for the quadratic perturbation with a given sigma.
Expand All @@ -399,7 +400,7 @@ def quadratic_perturbation_prox_factory(sigma):
else:
sigma = np.asarray(sigma)

const = 1.0 / np.sqrt(sigma * 2.0 * a + 1)
const = 1 / np.sqrt(2 * sigma * a + 1)
prox = proximal_arg_scaling(prox_factory, const)(sigma)
space = prox.domain
if u is not None:
Expand Down Expand Up @@ -486,8 +487,7 @@ def proximal_composition_factory(sigma):
Id = IdentityOperator(operator.domain)
Ir = IdentityOperator(operator.range)
prox_muf = proximal(mu * sigma)
return (Id +
(1.0 / mu) * operator.adjoint * ((prox_muf - Ir) * operator))
return Id + (1 / mu) * operator.adjoint * ((prox_muf - Ir) * operator)

return proximal_composition_factory

Expand Down Expand Up @@ -788,7 +788,7 @@ def _call(self, x, out):
else:
step = np.infty

if step < 1.0:
if step < 1:
self.range.lincomb(1 - step, x, out=out)
else:
self.range.lincomb(0, out, out=out)
Expand All @@ -800,7 +800,7 @@ def _call(self, x, out):
else:
step = np.infty

if step < 1.0:
if step < 1:
self.range.lincomb(1 - step, x, step, g, out=out)
else:
self.range.assign(out, g)
Expand Down Expand Up @@ -1012,7 +1012,7 @@ def _call(self, x, out):
space.lincomb(1, x, 1, tmp, out=out)
else:
F.multiply(sig, 2 * lam * g, out=out)
space.lincomb.lincomb(1, x, 1, out, out=out)
space.lincomb(1, x, 1, out, out=out)
F.divide(out, 1 + 2 * sig * lam, out=out)

return ProximalL2Squared
Expand Down Expand Up @@ -1363,7 +1363,7 @@ def _call(self, x, out):
F.divide(diff, denom, out=out)

# out = x - ...
space.lincomb(1, x, -1, out, out=out)
space.lincomb(1, x_old, -1, out, out=out)

return ProximalL1

Expand Down Expand Up @@ -1810,7 +1810,7 @@ def _call(self, x, out):
# out = ... + 4*lam*sigma*g
# If g is None, it is taken as the one element
if g is None:
out += 4.0 * lam * self.sigma
out += 4 * lam * self.sigma
else:
space.lincomb(1, out, 4 * lam * self.sigma, g, out=out)

Expand Down
15 changes: 9 additions & 6 deletions odl/test/operator/operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,10 @@ def test_nonlinear_functional_operators():
assert C(x) == pytest.approx(mat(x / 2.0))


# test functions to dispatch
# Test functions to dispatch
# First doc line is the true signature
# Second doc line contains `has_out` and `out_optional` booleans
# Third doc line indicates whether the signature is OK for Operator._call
def f1(x):
"""f1(x)
False, False
Expand Down Expand Up @@ -898,25 +901,24 @@ def func(request):


def test_function_signature(func):

true_sig = func.__doc__.splitlines()[0].strip()
sig = _function_signature(func)
assert true_sig == sig


def test_dispatch_call_args(func):
# Unbound functions
true_has, true_opt = eval(func.__doc__.splitlines()[1].strip())
true_has_out, true_out_opt = eval(func.__doc__.splitlines()[1].strip())
good = func.__doc__.splitlines()[2].strip() == 'good'

if good:
truespec = getargspec(func)
truespec.args.insert(0, 'self')

has, opt, spec = _dispatch_call_args(unbound_call=func)
has_out, out_opt, spec = _dispatch_call_args(unbound_call=func)

assert has == true_has
assert opt == true_opt
assert has_out == true_has_out
assert out_opt == true_out_opt
assert spec == truespec
else:
with pytest.raises(ValueError):
Expand All @@ -926,6 +928,7 @@ def test_dispatch_call_args(func):
def test_dispatch_call_args_class():

# Two sneaky classes whose _call method would pass the signature check
# because it looks okay from the second argument on
class WithStaticMethod(object):
@staticmethod
def _call(x, y, out):
Expand Down
2 changes: 2 additions & 0 deletions odl/util/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def simple_fixture(name, params, fmt=None):


# Helpers to generate data

# TODO(kohr-h): rename to noise_np_array
def noise_array(space):
"""Generate a white noise array for ``space``.
Expand Down

0 comments on commit 002fc6f

Please sign in to comment.