Skip to content

Commit

Permalink
Merge pull request #326 from carterbox/probe-updates
Browse files Browse the repository at this point in the history
API: Move finite support probe constraints to probe constraints section
  • Loading branch information
a4894z authored Jul 24, 2024
2 parents 4e48e2d + 08450de commit ccd45b9
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 51 deletions.
29 changes: 20 additions & 9 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,18 +809,29 @@ def constrain_center_peak(probe):
stack = probe.reshape((-1, *probe.shape[-2:]))
intensity = cupyx.scipy.ndimage.gaussian_filter(
input=np.sum(np.square(np.abs(stack)), axis=0),
sigma=half,
mode='wrap',
sigma=(half[0] / 3, half[1] / 3),
mode="constant",
cval=0.0,
truncate=6.0,
)
# Find the maximum intensity in 2D.
center = np.argmax(intensity)
# Find the 2D coordinates of the maximum.
coords = cp.unravel_index(center, dims=probe.shape[-2:])
# Shift each of the probes so the max is in the center.
p = np.roll(stack, half[0] - coords[0], axis=-2)
stack = np.roll(p, half[1] - coords[1], axis=-1)
coords = cp.round(cupyx.scipy.ndimage.center_of_mass(intensity))
# Shift each of the probes so the max is in the center. Take integer steps
# only one pixel at a time.
shifted = cupyx.scipy.ndimage.shift(
stack,
shift=(
0,
min(1, max(-1, half[0] - coords[0])),
min(1, max(-1, half[1] - coords[1])),
),
mode="constant",
cval=0.0,
order=0,
)
assert shifted.dtype == stack.dtype, (shifted.dtype, stack.dtype)
# Reform to the original shape; make contiguous.
probe = stack.reshape(probe.shape)
probe = shifted.reshape(probe.shape)
return probe


Expand Down
22 changes: 22 additions & 0 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
get_varying_probe,
apply_median_filter_abs_probe,
orthogonalize_eig,
finite_probe_support,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -715,6 +716,27 @@ def _apply_probe_constraints(
if parameters.probe_options is not None:
if parameters.probe_options.recover_probe(epoch):

if parameters.probe_options.probe_support > 0:
b0 = finite_probe_support(
parameters.probe,
p=parameters.probe_options.probe_support,
radius=parameters.probe_options.probe_support_radius,
degree=parameters.probe_options.probe_support_degree,
)
parameters.probe -= b0 * cp.conj(b0 * parameters.probe)

if parameters.probe_options.additional_probe_penalty > 0:
b1 = (
parameters.probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
parameters.probe.shape[-3],
dtype=tike.precision.floating,
)[..., None, None]
)
parameters.probe -= b1 * cp.conj(b1 * parameters.probe)

if parameters.probe_options.median_filter_abs_probe:
parameters.probe = apply_median_filter_abs_probe(
parameters.probe,
Expand Down
29 changes: 0 additions & 29 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,35 +710,6 @@ def _precondition_nearplane_gradients(
A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
unique_probe[..., m : m + 1, :, :],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

b1 = (
probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
probe[0].shape[-3],
dtype=tike.precision.floating,
)[..., m : m + 1, None, None]
)

m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :]
# / (
# (1 - alpha) * probe_update_denominator
# + alpha
# * probe_update_denominator.max(
# axis=(-2, -1),
# keepdims=True,
# )
# + b0
# + b1
# )

dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))

Expand Down
14 changes: 1 addition & 13 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,26 +264,14 @@ def _update(
psi = psi + dpsi / deno

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
probe,
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)
b1 = (
probe_options.additional_probe_penalty
* cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None]
)
dprobe = probe_update_numerator - (b1 + b0) * probe
dprobe = probe_update_numerator
deno = (
(1 - algorithm_options.alpha) * probe_options.preconditioner
+ algorithm_options.alpha
* probe_options.preconditioner.max(
axis=(-2, -1),
keepdims=True,
)
+ b0
+ b1
)
probe = probe + dprobe / deno
if probe_options.use_adaptive_moment:
Expand Down
14 changes: 14 additions & 0 deletions tests/ptycho/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ def test_hermite_modes():
np.rollaxis(inputs['result'], -1, 0)[None, ...],
)

def test_center_peak():

x = cp.ones((1, 1, 1, 7, 7), dtype=cp.complex64)

x[0,0,0, 3, 6] = 10 + 23j

print()
print(x.squeeze())

y = tike.ptycho.probe.constrain_center_peak(x)

print()
print(np.round(y.squeeze(), 1))


if __name__ == '__main__':
unittest.main()

0 comments on commit ccd45b9

Please sign in to comment.