diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index b0228dee..409ea9c5 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -85,6 +85,7 @@ get_varying_probe, apply_median_filter_abs_probe, orthogonalize_eig, + finite_probe_support, ) logger = logging.getLogger(__name__) @@ -715,6 +716,27 @@ def _apply_probe_constraints( if parameters.probe_options is not None: if parameters.probe_options.recover_probe(epoch): + if parameters.probe_options.probe_support > 0: + b0 = finite_probe_support( + parameters.probe, + p=parameters.probe_options.probe_support, + radius=parameters.probe_options.probe_support_radius, + degree=parameters.probe_options.probe_support_degree, + ) + parameters.probe -= b0 * cp.conj(b0 * parameters.probe) + + if parameters.probe_options.additional_probe_penalty > 0: + b1 = ( + parameters.probe_options.additional_probe_penalty + * cp.linspace( + 0, + 1, + parameters.probe.shape[-3], + dtype=tike.precision.floating, + )[..., None, None] + ) + parameters.probe -= b1 * cp.conj(b1 * parameters.probe) + if parameters.probe_options.median_filter_abs_probe: parameters.probe = apply_median_filter_abs_probe( parameters.probe, diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 3700d123..40b94571 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -710,35 +710,6 @@ def _precondition_nearplane_gradients( A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1)) if recover_probe: - b0 = tike.ptycho.probe.finite_probe_support( - unique_probe[..., m : m + 1, :, :], - p=probe_options.probe_support, - radius=probe_options.probe_support_radius, - degree=probe_options.probe_support_degree, - ) - - b1 = ( - probe_options.additional_probe_penalty - * cp.linspace( - 0, - 1, - probe[0].shape[-3], - dtype=tike.precision.floating, - )[..., m : m + 1, None, None] - ) - - m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :] - # / ( - # (1 - alpha) * probe_update_denominator - # + alpha - # * probe_update_denominator.max( - # axis=(-2, -1), - # keepdims=True, - # ) - # + b0 - # + b1 - # ) - dPO = m_probe_update[..., m:m + 1, :, :] * patches A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1)) diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 878bad82..1f8837ee 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -264,17 +264,7 @@ def _update( psi = psi + dpsi / deno if recover_probe: - b0 = tike.ptycho.probe.finite_probe_support( - probe, - p=probe_options.probe_support, - radius=probe_options.probe_support_radius, - degree=probe_options.probe_support_degree, - ) - b1 = ( - probe_options.additional_probe_penalty - * cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None] - ) - dprobe = probe_update_numerator - (b1 + b0) * probe + dprobe = probe_update_numerator deno = ( (1 - algorithm_options.alpha) * probe_options.preconditioner + algorithm_options.alpha @@ -282,8 +272,6 @@ def _update( axis=(-2, -1), keepdims=True, ) - + b0 - + b1 ) probe = probe + dprobe / deno if probe_options.use_adaptive_moment: