diff --git a/src/tike/cluster.py b/src/tike/cluster.py index 484fe053..de4f2268 100644 --- a/src/tike/cluster.py +++ b/src/tike/cluster.py @@ -17,6 +17,7 @@ def _split_gpu( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a CuPy array on the current device.""" return cp.asarray(x[m], dtype=dtype) @@ -25,6 +26,7 @@ def _split_host( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a NumPy array.""" return np.asarray(x[m], dtype=dtype) @@ -33,6 +35,7 @@ def _split_pinned( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a CuPy pinned host memory array.""" pinned = cupyx.empty_pinned(shape=(len(m), *x.shape[1:]), dtype=dtype) pinned[...] = x[m] return pinned @@ -171,22 +174,24 @@ def by_scan_stripes( def by_scan_stripes_contiguous( - *args, pool: tike.communicators.ThreadPool, shape: typing.Tuple[int], - dtype: typing.List[npt.DTypeLike], - destination: typing.List[str], scan: npt.NDArray[np.float32], - fly: int = 1, - batch_method, + batch_method: typing.Literal[ + "compact", "wobbly_center", "wobbly_center_random_bootstrap" + ], num_batch: int, -) -> typing.Tuple[typing.List[npt.NDArray], - typing.List[typing.List[npt.NDArray]]]: - """Split data by into stripes and create contiguously ordered batches. - - Divide the field of view into one stripe per devices; within each stripe, - create batches according to the batch_method loading the batches into - contiguous blocks in device memory. +) -> typing.Tuple[ + typing.List[npt.NDArray], + typing.List[typing.List[npt.NDArray]], + typing.List[int], +]: + """Return the indices that will split `scan` into 2D stripes of equal count + and create contiguously ordered batches within those stripes. + + Divide the field of view into one stripe per worker in `pool`; within each + stripe, create batches according to the batch_method loading the batches + into contiguous blocks in device memory. 
Parameters ---------- @@ -206,14 +211,14 @@ def by_scan_stripes_contiguous( Returns ------- order : List[array[int]] - The locations of the inputs in the original arrays. + For each worker in pool, the indices of the data batches : List[List[array[int]]] - The locations of the elements of each batch - scan : List[array[float32]] - The divided 2D coordinates of the scan positions. - args : List[array[float32]] or None - Each input divided into regions or None if arg was None. - + For each worker in pool, for each batch, the indices of the elements of + each batch + stripe_start : List[int] + The coordinates of the leading edge of each stripe along the 0th + dimension in the scan coordinates. e.g., the minimum coordinate of the + scan positions in each stripe. """ if len(shape) != 2: raise ValueError('The grid shape must have two dimensions.') @@ -229,6 +234,7 @@ def by_scan_stripes_contiguous( x=scan, dtype=scan.dtype, ) + stripe_start = [int(np.floor(np.min(x[:, 0]))) for x in split_scan] batches_noncontiguous: typing.List[typing.List[npt.NDArray]] = pool.map( getattr(tike.cluster, batch_method), split_scan, @@ -247,26 +253,13 @@ def by_scan_stripes_contiguous( batch_breaks, )) - split_args = [] - for arg, t, dest in zip([scan, *args], dtype, destination): - if arg is None: - split_args.append(None) - else: - split_args.append( - pool.map( - _split_gpu if dest == 'gpu' else _split_pinned, - map_to_gpu_contiguous, - x=arg, - dtype=t, - )) - if __debug__: for device in batches_contiguous: assert len(device) == num_batch, ( f"There should be {num_batch} batches, found {len(device)}" ) - return (map_to_gpu_contiguous, batches_contiguous, *split_args) + return (map_to_gpu_contiguous, batches_contiguous, stripe_start) def stripes_equal_count( @@ -306,7 +299,10 @@ def stripes_equal_count( ) -def wobbly_center(population, num_cluster): +def wobbly_center( + population: npt.ArrayLike, + num_cluster: int, +) -> typing.List[npt.NDArray]: """Return the indices that divide 
population into heterogenous clusters. Uses a contrarian approach to clustering by maximizing the heterogeneity @@ -382,7 +378,7 @@ def wobbly_center(population, num_cluster): def wobbly_center_random_bootstrap( - population, + population: npt.ArrayLike, num_cluster: int, boot_fraction: float = 0.95, ) -> typing.List[npt.NDArray]: @@ -466,7 +462,11 @@ def wobbly_center_random_bootstrap( return [cp.asnumpy(xp.flatnonzero(labels == c)) for c in range(num_cluster)] -def compact(population, num_cluster, max_iter=500): +def compact( + population: npt.ArrayLike, + num_cluster: int, + max_iter: int = 500, +) -> typing.List[npt.NDArray]: """Return the indices that divide population into compact clusters. Uses an approach that is inspired by the naive k-means algorithm, but it diff --git a/src/tike/communicators/comm.py b/src/tike/communicators/comm.py index c80d7b01..06a2387f 100644 --- a/src/tike/communicators/comm.py +++ b/src/tike/communicators/comm.py @@ -13,10 +13,6 @@ from .pool import ThreadPool -def _init_streams(): - return [cp.cuda.Stream() for _ in range(2)] - - class Comm: """A Ptychography communicator. @@ -47,7 +43,6 @@ def __init__( self.use_mpi = True self.mpi = mpi() self.pool = pool(gpu_count) - self.streams = self.pool.map(_init_streams) def __enter__(self): self.mpi.__enter__() @@ -139,3 +134,22 @@ def Allreduce( buf.append( self.mpi.Allreduce(src[self.pool.workers.index(worker)])) return buf + + def swap_edges( + self, + x: typing.List[cp.ndarray], + overlap: int, + edges: typing.List[int], + ) -> typing.List[cp.ndarray]: + """Swap the region of each x with its neighbor around the given edges. + + Given iterable x, a list of ND arrays; edges, the coordinates in x + along dimension -2; and overlap, the width of the region to swap around + the edge; trade [..., edge:(edge + overlap), :] between neighbors. 
+ """ + # FIXME: Swap edges between MPI nodes + return self.pool.swap_edges( + x=x, + overlap=overlap, + edges=edges, + ) diff --git a/src/tike/communicators/pool.py b/src/tike/communicators/pool.py index 4a67cd4d..78ec4056 100644 --- a/src/tike/communicators/pool.py +++ b/src/tike/communicators/pool.py @@ -89,6 +89,12 @@ def __init__( self.num_workers) if self.num_workers > 1 else NoPoolExecutor( self.num_workers) + def f(worker): + with self.Device(worker): + return [cp.cuda.Stream() for _ in range(2)] + + self.streams = list(self.executor.map(f, self.workers)) + def __enter__(self): if self.workers[0] != cp.cuda.Device().id: raise ValueError( @@ -397,10 +403,74 @@ def map( ) -> list: """ThreadPoolExecutor.map, but wraps call in a cuda.Device context.""" - def f(worker, *args): + def f(worker, streams, *args): with self.Device(worker): - return func(*args, **kwargs) + with streams[1]: + return func(*args, **kwargs) workers = self.workers if workers is None else workers - return list(self.executor.map(f, workers, *iterables)) + return list(self.executor.map(f, workers, self.streams, *iterables)) + + def swap_edges( + self, + x: typing.List[cp.ndarray], + overlap: int, + edges: typing.List[int], + ): + """Swap [..., edge:(edge + overlap), :] between neighbors in-place + + For example, given overlap=3 and edges=[0, 4, 8, 12], the following + swap would be returned: + + ``` + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + + [[0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0]] + [[1 1 1 1 0 0 0 1 2 2 2 1 1 1 1 1]] + [[2 2 2 2 2 2 2 2 1 1 1 2 3 3 3 2]] + [[3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 3]] + ``` + + Note that the minimum swapped region is 1 wide. 
+ + """ + if overlap < 1: + msg = f"Overlap for swap_edges cannot be less than 1: {overlap}" + raise ValueError(msg) + for i in range(self.num_workers - 1): + lo = edges[i + 1] + hi = lo + overlap + temp0 = self._copy_to(x[i][..., lo:hi, :], self.workers[i + 1]) + temp1 = self._copy_to(x[i + 1][..., lo:hi, :], self.workers[i]) + with self.Device(self.workers[i]): + rampu = cp.linspace( + 0.0, + 1.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + rampd = cp.linspace( + 1.0, + 0.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + x[i][..., lo:hi, :] = rampd * x[i][..., lo:hi, :] + rampu * temp1 + with self.Device(self.workers[i + 1]): + rampu = cp.linspace( + 0.0, + 1.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + rampd = cp.linspace( + 1.0, + 0.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + x[i + 1][..., lo:hi, :] = ( + rampd * temp0 + rampu * x[i + 1][..., lo:hi, :] + ) + return x diff --git a/src/tike/opt.py b/src/tike/opt.py index ac063e8e..34b7961f 100644 --- a/src/tike/opt.py +++ b/src/tike/opt.py @@ -381,9 +381,9 @@ def conjugate_gradient( def fit_line_least_squares( - y: typing.List[float], - x: typing.List[float], -) -> typing.Tuple[float, float]: + y: npt.NDArray[np.floating], + x: npt.NDArray[np.floating], +) -> typing.Tuple[np.floating, np.floating]: """Return the `slope`, `intercept` pair that best fits `y`, `x` to a line. y = slope * x + intercept diff --git a/src/tike/ptycho/exitwave.py b/src/tike/ptycho/exitwave.py index eec3d0b3..ce0fd3e0 100644 --- a/src/tike/ptycho/exitwave.py +++ b/src/tike/ptycho/exitwave.py @@ -5,6 +5,7 @@ just free space propagation to the detector. """ + from __future__ import annotations import copy @@ -67,7 +68,7 @@ class ExitWaveOptions: exitwave updates in Fourier space. `1.0` for no scaling. 
""" - propagation_normalization: str = 'ortho' + propagation_normalization: str = "ortho" """Choose the scaling of the FFT in the forward model: "ortho" - the forward and adjoint operations are divided by sqrt(n) @@ -78,19 +79,29 @@ class ExitWaveOptions: """ - def copy_to_device(self, comm) -> ExitWaveOptions: + def copy_to_device(self) -> ExitWaveOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) - if self.measured_pixels is not None: - options.measured_pixels = comm.pool.bcast([self.measured_pixels]) - return options + return ExitWaveOptions( + measured_pixels=cp.asarray(self.measured_pixels, dtype=bool), + noise_model=self.noise_model, + propagation_normalization=self.propagation_normalization, + step_length_start=self.step_length_start, + step_length_usemodes=self.step_length_usemodes, + step_length_weight=self.step_length_weight, + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + ) def copy_to_host(self) -> ExitWaveOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) - if self.measured_pixels is not None: - options.measured_pixels = cp.asnumpy(self.measured_pixels[0]) - return options + return ExitWaveOptions( + measured_pixels=cp.asnumpy(self.measured_pixels), + noise_model=self.noise_model, + propagation_normalization=self.propagation_normalization, + step_length_start=self.step_length_start, + step_length_usemodes=self.step_length_usemodes, + step_length_weight=self.step_length_weight, + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + ) def resample(self, factor: float) -> ExitWaveOptions: """Return a new `ExitWaveOptions` with the parameters rescaled.""" @@ -103,8 +114,9 @@ def resample(self, factor: float) -> ExitWaveOptions: self.measured_pixels, int(self.measured_pixels.shape[-1] * factor), ), - unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, - propagation_normalization=self.propagation_normalization ) + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + 
propagation_normalization=self.propagation_normalization, + ) def poisson_steplength_all_modes( diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 1ca8f643..6677b0f3 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -71,37 +71,56 @@ class ObjectOptions: ) """The magnitude of the illumination used for conditioning the object updates.""" - combined_update: typing.Union[npt.NDArray, None] = dataclasses.field( - init=False, - default_factory=lambda: None, - ) - """Used for compact batch updates.""" - clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" - def copy_to_device(self, comm) -> ObjectOptions: + def copy_to_device(self) -> ObjectOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) + options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, + positivity_constraint=self.positivity_constraint, + smoothness_constraint=self.smoothness_constraint, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + clip_magnitude=self.clip_magnitude, + ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: - options.v = cp.asarray(self.v) + options.v = cp.asarray( + self.v, + dtype=tike.precision.floating, + ) if self.m is not None: - options.m = cp.asarray(self.m) + options.m = cp.asarray( + self.m, + dtype=tike.precision.floating, + ) if self.preconditioner is not None: - options.preconditioner = comm.pool.bcast([self.preconditioner]) + options.preconditioner = cp.asarray( + self.preconditioner, + dtype=tike.precision.cfloating, + ) return options def copy_to_host(self) -> ObjectOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) + options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, + positivity_constraint=self.positivity_constraint, + smoothness_constraint=self.smoothness_constraint, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, 
+ mdecay=self.mdecay, + clip_magnitude=self.clip_magnitude, + ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: options.v = cp.asnumpy(self.v) if self.m is not None: options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - options.preconditioner = cp.asnumpy(self.preconditioner[0]) + options.preconditioner = cp.asnumpy(self.preconditioner) return options def resample(self, factor: float, interp) -> ObjectOptions: @@ -119,6 +138,58 @@ def resample(self, factor: float, interp) -> ObjectOptions: return options # Momentum reset to zero when grid scale changes + @staticmethod + def join_psi( + x: typing.List[np.ndarray], + stripe_start: typing.List[int], + probe_width: int, + ) -> np.ndarray: + """Recombine `x`, a list of psi, from a split reconstruction.""" + joined_psi = x[0] + w = probe_width // 2 + for i in range(1, len(x)): + lo = stripe_start[i] + w + hi = stripe_start[i + 1] + w if i + 1 < len(x) else x[0].shape[1] + joined_psi[:, lo:hi, :] = x[i][:, lo:hi, :] + return joined_psi + + @staticmethod + def join( + x: typing.List[ObjectOptions], + stripe_start: typing.List[int], + probe_width: int, + ) -> ObjectOptions: + """Recombine `x`, a list of ObjectOptions, from split ObjectOptions""" + options = ObjectOptions( + convergence_tolerance=x[0].convergence_tolerance, + positivity_constraint=x[0].positivity_constraint, + smoothness_constraint=x[0].smoothness_constraint, + use_adaptive_moment=x[0].use_adaptive_moment, + vdecay=x[0].vdecay, + mdecay=x[0].mdecay, + clip_magnitude=x[0].clip_magnitude, + ) + options.update_mnorm = copy.copy(x[0].update_mnorm) + if x[0].v is not None: + options.v = ObjectOptions.join_psi( + [e.v for e in x], + stripe_start, + probe_width, + ) + if x[0].m is not None: + options.m = ObjectOptions.join_psi( + [e.m for e in x], + stripe_start, + probe_width, + ) + if x[0].preconditioner is not None: + options.preconditioner = ObjectOptions.join_psi( + [e.preconditioner for e in x], + stripe_start, + 
probe_width, + ) + return options + def positivity_constraint(x, r): """Constrains the amplitude of x to be positive with sum of abs(x) and x. diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 0f02f835..d9766712 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -124,6 +124,7 @@ import cupy as cp import cupyx.scipy.ndimage import numpy as np +import numpy.typing as npt import tike.communicators import tike.linalg @@ -156,7 +157,15 @@ def resample(self, factor: float) -> AffineTransform: ) @classmethod - def fromarray(self, T: np.ndarray) -> AffineTransform: + def frombuffer(cls, buffer: np.ndarray) -> AffineTransform: + return AffineTransform(*buffer) + + def asbuffer(self) -> np.ndarray: + """Return the constructor parameters in a NumPy array.""" + return np.array(self.astuple()) + + @classmethod + def fromarray(cls, T: np.ndarray) -> AffineTransform: """Return an Affine Transfrom from a 2x2 matrix. Use decomposition method from Graphics Gems 2 Section 7.1 @@ -180,8 +189,8 @@ def fromarray(self, T: np.ndarray) -> AffineTransform: scale1=float(scale1), shear1=float(shear1), angle=float(angle), - t0=T[2, 0] if T.shape[0] > 2 else 0, - t1=T[2, 1] if T.shape[0] > 2 else 0, + t0=float(T[2, 0] if T.shape[0] > 2 else 0), + t1=float(T[2, 1] if T.shape[0] > 2 else 0), ) def asarray(self, xp=np) -> np.ndarray: @@ -347,7 +356,10 @@ class PositionOptions: transform: AffineTransform = AffineTransform() """Global transform of positions.""" - origin: tuple[float, float] = (0, 0) + origin: npt.NDArray = dataclasses.field( + init=True, + default_factory=lambda: np.zeros(2), + ) """The rotation center of the transformation. 
This shift is applied to the scan positions before computing the global transformation.""" @@ -360,6 +372,11 @@ class PositionOptions: update_start: int = 0 """Start position updates at this epoch.""" + _momentum: np.ndarray = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + def __post_init__(self): self.initial_scan = self.initial_scan.astype(tike.precision.floating) if self.confidence is None: @@ -413,7 +430,7 @@ def empty(self): new._momentum = np.empty((0, 4)) return new - def split(self, indices): + def split(self, indices: npt.NDArray[np.intc]) -> PositionOptions: """Split the PositionOption meta-data along indices.""" new = PositionOptions( self.initial_scan[..., indices, :], @@ -439,49 +456,76 @@ def insert(self, other, indices): self._momentum[..., indices, :] = other._momentum return self - def join(self, other, indices): - """Replace the PositionOption meta-data with other data.""" - len_scan = self.initial_scan.shape[-2] - max_index = max(indices.max() + 1, len_scan) - new_initial_scan = np.empty( - (*self.initial_scan.shape[:-2], max_index, 2), - dtype=self.initial_scan.dtype, + @staticmethod + def join( + x: typing.Iterable[PositionOptions | None], + reorder: npt.NDArray[np.intc], + ) -> PositionOptions | None: + if None in x: + return None + new = PositionOptions( + initial_scan=np.concatenate( + [e.initial_scan for e in x], + axis=0, + )[reorder], + use_adaptive_moment=x[0].use_adaptive_moment, + vdecay=x[0].vdecay, + mdecay=x[0].mdecay, + use_position_regularization=x[0].use_position_regularization, + update_magnitude_limit=x[0].update_magnitude_limit, + transform=x[0].transform, ) - new_initial_scan[..., :len_scan, :] = self.initial_scan - new_initial_scan[..., indices, :] = other.initial_scan - self.initial_scan = new_initial_scan - if self.confidence is not None: - new_confidence = np.empty( - (*self.initial_scan.shape[:-2], max_index, 2), - dtype=self.initial_scan.dtype, - ) - new_confidence[..., :len_scan, :] = 
self.confidence - new_confidence[..., indices, :] = other.confidence - self.confidence = new_confidence - if self.use_adaptive_moment: - new_momentum = np.empty( - (*self.initial_scan.shape[:-2], max_index, 4), - dtype=self.initial_scan.dtype, - ) - new_momentum[..., :len_scan, :] = self._momentum - new_momentum[..., indices, :] = other._momentum - self._momentum = new_momentum - return self + if x[0].confidence is not None: + new.confidence = np.concatenate( + [e.confidence for e in x], + axis=0, + )[reorder] + + if x[0].use_adaptive_moment: + new._momentum = np.concatenate( + [e._momentum for e in x], + axis=0, + )[reorder] + + return new def copy_to_device(self): """Copy to the current GPU memory.""" - options = copy.copy(self) - options.initial_scan = cp.asarray(self.initial_scan) + options = PositionOptions( + initial_scan=cp.asarray(self.initial_scan), + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + use_position_regularization=self.use_position_regularization, + update_magnitude_limit=self.update_magnitude_limit, + transform=self.transform, + confidence=self.confidence, + update_start=self.update_start, + origin=cp.asarray(self.origin), + ) if self.confidence is not None: options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: - options._momentum = cp.asarray(self._momentum) + options._momentum = cp.asarray( + self._momentum, + dtype=tike.precision.floating, + ) return options def copy_to_host(self): """Copy to the host CPU memory.""" - options = copy.copy(self) - options.initial_scan = cp.asnumpy(self.initial_scan) + options = PositionOptions( + initial_scan=cp.asnumpy(self.initial_scan), + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + use_position_regularization=self.use_position_regularization, + update_magnitude_limit=self.update_magnitude_limit, + transform=self.transform, + confidence=self.confidence, + update_start=self.update_start, + 
origin=cp.asnumpy(self.origin), + ) if self.confidence is not None: options.confidence = cp.asnumpy(self.confidence) if self.use_adaptive_moment: @@ -499,6 +543,8 @@ def resample(self, factor: float) -> PositionOptions: update_magnitude_limit=self.update_magnitude_limit, transform=self.transform.resample(factor), confidence=self.confidence, + update_start=self.update_start, + origin=self.origin * factor, ) # Momentum reset to zero when grid scale changes return new @@ -569,8 +615,12 @@ def check_allowed_positions(scan: np.array, psi: np.array, probe_shape: tuple): valid_min_corner = (1, 1) valid_max_corner = (psi.shape[-2] - probe_shape[-2] - 1, psi.shape[-1] - probe_shape[-1] - 1) - if (np.any(min_corner < valid_min_corner) - or np.any(max_corner > valid_max_corner)): + if ( + min_corner[0] < valid_min_corner[0] + or min_corner[1] < valid_min_corner[1] + or max_corner[0] > valid_max_corner[0] + or max_corner[1] > valid_max_corner[1] + ): raise ValueError( "Scan positions must be >= 1 and " "scan positions + 1 + probe.shape must be <= psi.shape. " @@ -680,12 +730,10 @@ def _affine_position_helper( # TODO: What is a good default value for max_error? def affine_position_regularization( - comm: tike.communicators.Comm, - updated: typing.List[cp.ndarray], - position_options: typing.List[PositionOptions], + updated: cp.ndarray, + position_options: PositionOptions, max_error: float = 32, - regularization_enabled: bool = False, -) -> typing.Tuple[typing.List[cp.ndarray], typing.List[PositionOptions]]: +) -> typing.Tuple[cp.ndarray, PositionOptions]: """Regularize position updates with an affine deformation constraint. 
Assume that the true position updates are a global affine transformation @@ -707,30 +755,20 @@ def affine_position_regularization( """ # Gather all of the scanning positions on one host - positions0 = comm.pool.gather_host( - [x.initial_scan for x in position_options], axis=0) - positions1 = comm.pool.gather_host(updated, axis=0) - positions0 = comm.mpi.Gather(positions0, axis=0, root=0) - positions1 = comm.mpi.Gather(positions1, axis=0, root=0) - - if comm.mpi.rank == 0: - new_transform, _ = estimate_global_transformation_ransac( - positions0=positions0 - position_options[0].origin, - positions1=positions1 - position_options[0].origin, - transform=position_options[0].transform, - max_error=max_error, - ) - else: - new_transform = None - - new_transform = comm.mpi.bcast(new_transform, root=0) + positions0 = position_options.initial_scan + positions1 = updated + + new_transform, _ = estimate_global_transformation_ransac( + positions0=positions0 - position_options.origin, + positions1=positions1 - position_options.origin, + transform=position_options.transform, + max_error=max_error, + ) - for i in range(len(position_options)): - position_options[i].transform = new_transform + position_options.transform = new_transform - if regularization_enabled: - updated = comm.pool.map( - _affine_position_helper, + if position_options.use_position_regularization: + updated = _affine_position_helper( updated, position_options, max_error=max_error, diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index a5d2e847..91a563a0 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -36,7 +36,6 @@ """ from __future__ import annotations -import copy import dataclasses import logging import typing @@ -57,9 +56,6 @@ class ProbeOptions: """Manage data and setting related to probe correction.""" - recover_probe: bool = False - """Boolean switch used to indicate whether to update probe or not.""" - update_start: int = 0 """Start probe updates at this epoch.""" @@ 
-139,19 +135,13 @@ class ProbeOptions: """ median_filter_abs_probe: bool = False - """Binary switch on whether to apply a median filter to absolute value of + """Binary switch on whether to apply a median filter to absolute value of each shared probe mode. """ median_filter_abs_probe_px: typing.Tuple[float, float] = ( 1.0, 1.0 ) """A 2-element tuple with the median filter pixel widths along each dimension.""" - probe_update_sum: typing.Union[npt.NDArray, None] = dataclasses.field( - init=False, - default_factory=lambda: None, - ) - """Used for momentum updates.""" - preconditioner: typing.Union[npt.NDArray, None] = dataclasses.field( init=False, default_factory=lambda: None, @@ -163,32 +153,80 @@ class ProbeOptions: ) """The power of the primary probe modes at each iteration.""" - def copy_to_device(self, comm) -> ProbeOptions: + def recover_probe(self, epoch: int) -> bool: + """Return whether to update probe or not.""" + return (epoch >= self.update_start) and (epoch % self.update_period == 0) + + def copy_to_device(self) -> ProbeOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) + options = ProbeOptions( + update_start=self.update_start, + update_period=self.update_period, + init_rescale_from_measurements=self.init_rescale_from_measurements, + probe_photons=self.probe_photons, + force_orthogonality=self.force_orthogonality, + force_centered_intensity=self.force_centered_intensity, + force_sparsity=self.force_sparsity, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + probe_support=self.probe_support, + probe_support_radius=self.probe_support_radius, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, + median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe_px=self.median_filter_abs_probe_px, + ) + options.power=self.power if self.v is not None: - options.v = cp.asarray(self.v) + options.v = cp.asarray( + self.v, + 
dtype=tike.precision.floating, + ) if self.m is not None: - options.m = cp.asarray(self.m) + options.m = cp.asarray( + self.m, + dtype=tike.precision.floating, + ) if self.preconditioner is not None: - options.preconditioner = comm.pool.bcast([self.preconditioner]) + options.preconditioner = cp.asarray( + self.preconditioner, + dtype=tike.precision.cfloating, + ) return options def copy_to_host(self) -> ProbeOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) + options = ProbeOptions( + update_start=self.update_start, + update_period=self.update_period, + init_rescale_from_measurements=self.init_rescale_from_measurements, + probe_photons=self.probe_photons, + force_orthogonality=self.force_orthogonality, + force_centered_intensity=self.force_centered_intensity, + force_sparsity=self.force_sparsity, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + probe_support=self.probe_support, + probe_support_radius=self.probe_support_radius, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, + median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe_px=self.median_filter_abs_probe_px, + ) + options.power=self.power if self.v is not None: options.v = cp.asnumpy(self.v) if self.m is not None: options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - options.preconditioner = cp.asnumpy(self.preconditioner[0]) + options.preconditioner = cp.asnumpy(self.preconditioner) return options def resample(self, factor: float, interp) -> ProbeOptions: """Return a new `ProbeOptions` with the parameters rescaled.""" options = ProbeOptions( - recover_probe=self.recover_probe, update_start=self.update_start, update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, @@ -200,15 +238,16 @@ def resample(self, factor: float, interp) -> ProbeOptions: vdecay=self.vdecay, mdecay=self.mdecay, 
probe_support=self.probe_support, - probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, - median_filter_abs_probe=self.median_filter_abs_probe, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, + median_filter_abs_probe=self.median_filter_abs_probe, median_filter_abs_probe_px=self.median_filter_abs_probe_px, ) + options.power=self.power return options # Momentum reset to zero when grid scale changes - def get_varying_probe(shared_probe, eigen_probe=None, weights=None): """Construct the varying probes. @@ -243,8 +282,19 @@ def get_varying_probe(shared_probe, eigen_probe=None, weights=None): return shared_probe.copy() -def _constrain_variable_probe1(variable_probe, weights): - """Help use the thread pool with constrain_variable_probe""" +def constrain_variable_probe(variable_probe, weights): + """Add the following constraints to variable probe weights + + 1. Remove outliers from weights + 2. Enforce orthogonality once per epoch + 3. Sort the variable probes by their total energy + 4. Normalize the variable probes so the energy is contained in the weight + + """ + # TODO: No smoothing of variable probe weights yet because the weights are + # not stored consecutively in device memory. Smoothing would require either + # sorting and synchronizing the weights with the host OR implementing + # smoothing of non-gridded data with splines using device-local data only. 
# Normalize variable probes vnorm = tike.linalg.mnorm(variable_probe, axis=(-2, -1), keepdims=True) @@ -266,12 +316,6 @@ def _constrain_variable_probe1(variable_probe, weights): axis=-3, )**2 - return variable_probe, weights, power - - -def _constrain_variable_probe2(variable_probe, weights, power): - """Help use the thread pool with constrain_variable_probe""" - # Sort the probes by energy probes_with_modes = variable_probe.shape[-3] for i in range(probes_with_modes): @@ -294,124 +338,7 @@ def _constrain_variable_probe2(variable_probe, weights, power): return variable_probe, weights -def constrain_variable_probe(comm, variable_probe, weights): - """Add the following constraints to variable probe weights - - 1. Remove outliars from weights - 2. Enforce orthogonality once per epoch - 3. Sort the variable probes by their total energy - 4. Normalize the variable probes so the energy is contained in the weight - - """ - # TODO: No smoothing of variable probe weights yet because the weights are - # not stored consecutively in device memory. Smoothing would require either - # sorting and synchronizing the weights with the host OR implementing - # smoothing of non-gridded data with splines using device-local data only. 
- - variable_probe, weights, power = zip(*comm.pool.map( - _constrain_variable_probe1, - variable_probe, - weights, - )) - - # reduce power by sum across all devices - power = comm.pool.allreduce(power) - - variable_probe, weights = (list(a) for a in zip(*comm.pool.map( - _constrain_variable_probe2, - variable_probe, - weights, - power, - ))) - - return variable_probe, weights - - -def _get_update(R, eigen_probe, weights, batches, *, batch_index, c, m): - """ - Parameters - ---------- - R : (B, 1, 1, H, W) - eigen_probe (1, C, M, H, W) - weights : (B, C, M) - """ - lo = batches[batch_index][0] - hi = lo + len(batches[batch_index]) - # (POSI, 1, 1, 1, 1) to match other arrays - weights = weights[lo:hi, c:c + 1, m:m + 1, None, None] - eigen_probe = eigen_probe[:, c - 1:c, m:m + 1, :, :] - norm_weights = tike.linalg.norm(weights, axis=-5, keepdims=True)**2 - - if np.all(norm_weights == 0): - raise ValueError('eigen_probe weights cannot all be zero?') - - # FIXME: What happens when weights is zero!? 
- proj = (np.real(R.conj() * eigen_probe) + weights) / norm_weights - return np.mean( - R * np.mean(proj, axis=(-2, -1), keepdims=True), - axis=-5, - keepdims=True, - ) - - -def _get_d(patches, diff, eigen_probe, update, *, β, c, m): - """ - Parameters - ---------- - patches : (B, 1, 1, H, W) - diff : (B, 1, M, H, W) - eigen_probe (1, C, M, H, W) - update : (1, 1, 1, H, W) - """ - eigen_probe[:, c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm( - update, - axis=(-2, -1), - keepdims=True, - ) - eigen_probe[:, c - 1:c, m:m + 1, :, :] /= tike.linalg.mnorm( - eigen_probe[:, c - 1:c, m:m + 1, :, :], - axis=(-2, -1), - keepdims=True, - ) - assert np.all(np.isfinite(eigen_probe)) - - # Determine new eigen_weights for the updated eigen probe - phi = patches * eigen_probe[:, c - 1:c, m:m + 1, :, :] - n = np.mean( - np.real(diff[:, :, m:m + 1, :, :] * phi.conj()), - axis=(-1, -2), - keepdims=False, - ) - d = np.mean(np.square(np.abs(phi)), axis=(-1, -2), keepdims=False) - d_mean = np.mean(d, axis=-3, keepdims=True) - return eigen_probe, n, d, d_mean - - -def _get_weights_mean(n, d, d_mean, weights, batches, *, batch_index, c, m): - """ - Parameters - ---------- - n : (B, 1, 1) - d : (B, 1, 1) - d_mean : (1, 1, 1) - weights : (B, C, M) - """ - lo = batches[batch_index][0] - hi = lo + len(batches[batch_index]) - # yapf: disable - weight_update = ( - n / (d + 0.1 * d_mean) - ).reshape(*weights[lo:hi, c:c + 1, m:m + 1].shape) - # yapf: enable - assert np.all(np.isfinite(weight_update)) - - # (33) The sum of all previous steps constrained to zero-mean - weights[lo:hi, c:c + 1, m:m + 1] += weight_update - return weights - - def update_eigen_probe( - comm, R, eigen_probe, weights, @@ -433,8 +360,6 @@ def update_eigen_probe( Parameters ---------- - comm : :py:class:`tike.communicators.Comm` - An object which manages communications between both GPUs and nodes. 
R : (POSI, 1, 1, WIDE, HIGH) complex64 Residual probe updates; what's left after subtracting the shared probe update from the varying probe updates for each position @@ -459,55 +384,73 @@ def update_eigen_probe( least-squares solver for generalized maximum-likelihood ptychography. Optics Express. 2018. """ - assert R[0].shape[-3] == R[0].shape[-4] == 1 - assert 1 == eigen_probe[0].shape[-5] - assert R[0].shape[:-5] == eigen_probe[0].shape[:-5] == weights[0].shape[:-3] - assert weights[0][batches[0][batch_index], :, :].shape[-3] == R[0].shape[-5] - assert R[0].shape[-2:] == eigen_probe[0].shape[-2:] - - update = comm.pool.map( - _get_update, - R, - eigen_probe, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, + assert R.shape[-3] == R.shape[-4] == 1 + assert 1 == eigen_probe.shape[-5] + assert R.shape[:-5] == eigen_probe.shape[:-5] == weights.shape[:-3] + assert weights[batches[batch_index], :, :].shape[-3] == R.shape[-5] + assert R.shape[-2:] == eigen_probe.shape[-2:] + + lo = batches[batch_index][0] + hi = lo + len(batches[batch_index]) + # (POSI, 1, 1, 1, 1) to match other arrays + norm_weights = ( + tike.linalg.norm( + weights[lo:hi, c : c + 1, m : m + 1, None, None], + axis=-5, + keepdims=True, + ) + ** 2 ) - update = comm.pool.bcast([comm.Allreduce_mean( - update, + + if np.all(norm_weights == 0): + raise ValueError("eigen_probe weights cannot all be zero?") + + # FIXME: What happens when weights is zero!? 
+ proj = ( + np.real(R.conj() * eigen_probe[:, c - 1 : c, m : m + 1, :, :]) + + weights[lo:hi, c : c + 1, m : m + 1, None, None] + ) / norm_weights + update = np.mean( + R * np.mean(proj, axis=(-2, -1), keepdims=True), axis=-5, - )]) - - (eigen_probe, n, d, d_mean) = (list(a) for a in zip(*comm.pool.map( - _get_d, - patches, - diff, - eigen_probe, - update, - β=β, - c=c, - m=m, - ))) - - d_mean = comm.pool.bcast([comm.Allreduce_mean( - d_mean, - axis=-3, - )]) - - weights = list( - comm.pool.map( - _get_weights_mean, - n, - d, - d_mean, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, - )) + keepdims=False, + ) + + eigen_probe[:, c - 1 : c, m : m + 1, :, :] += ( + β + * update + / tike.linalg.mnorm( + update, + axis=(-2, -1), + keepdims=True, + ) + ) + eigen_probe[:, c - 1 : c, m : m + 1, :, :] /= tike.linalg.mnorm( + eigen_probe[:, c - 1 : c, m : m + 1, :, :], + axis=(-2, -1), + keepdims=True, + ) + assert np.all(np.isfinite(eigen_probe)) + + # Determine new eigen_weights for the updated eigen probe + phi = patches * eigen_probe[:, c - 1 : c, m : m + 1, :, :] + n = np.mean( + np.real(diff[:, :, m : m + 1, :, :] * phi.conj()), + axis=(-1, -2), + keepdims=False, + ) + d = np.mean(np.square(np.abs(phi)), axis=(-1, -2), keepdims=False) + d_mean = np.mean(d, axis=-3, keepdims=False) + + # yapf: disable + weight_update = ( + n / (d + 0.1 * d_mean) + ).reshape(*weights[lo:hi, c:c + 1, m:m + 1].shape) + # yapf: enable + assert np.all(np.isfinite(weight_update)) + + # (33) The sum of all previous steps constrained to zero-mean + weights[lo:hi, c : c + 1, m : m + 1] += weight_update return eigen_probe, weights @@ -941,7 +884,13 @@ def constrain_probe_sparsity(probe, f): return probe -def finite_probe_support(probe, *, radius=0.5, degree=5, p=1.0): +def finite_probe_support( + probe, + *, + radius: float = 0.5, + degree: float = 5.0, + p: float = 1.0, +): """Returns a supergaussian penalty function for finite probe support. 
A mask which provides an illumination penalty is determined by the equation: diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 3e78a661..b0228dee 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -77,12 +77,14 @@ PositionOptions, check_allowed_positions, affine_position_regularization, + AffineTransform, ) from .probe import ( constrain_center_peak, constrain_probe_sparsity, get_varying_probe, apply_median_filter_abs_probe, + orthogonalize_eig, ) logger = logging.getLogger(__name__) @@ -226,27 +228,28 @@ def reconstruct( use_mpi, ) as context: context.iterate(parameters.algorithm_options.num_iter) + result = context.get_result() if ( logger.getEffectiveLevel() <= logging.INFO - ) and context.parameters.position_options: + ) and result.position_options: mean_scaling = 0.5 * ( - context.parameters.position_options.transform.scale0 - + context.parameters.position_options.transform.scale1 + result.position_options.transform.scale0 + + result.position_options.transform.scale1 ) logger.info( f"Global scaling of {mean_scaling:.3e} detected from position correction." " Probably your estimate of photon energy and/or sample to detector " "distance is off by that amount." 
) - t = context.parameters.position_options.transform.asarray() + t = result.position_options.transform.asarray() logger.info(f"""Affine transform parameters: {t[0,0]: .3e}, {t[0,1]: .3e} {t[1,0]: .3e}, {t[1,1]: .3e} """) - return context.parameters + return result def _clip_magnitude(x, a_max): @@ -333,8 +336,10 @@ def __init__( else: mpi = tike.communicators.NoMPIComm - self.data = data - self.parameters = copy.deepcopy(parameters) + self.data: typing.List[npt.ArrayLike] = [data] + self.parameters: typing.List[solvers.PtychoParameters] = [ + copy.deepcopy(parameters) + ] self.device = cp.cuda.Device( num_gpu[0] if isinstance(num_gpu, tuple) else None) self.operator = tike.operators.Ptycho( @@ -352,7 +357,7 @@ def __enter__(self): self.comm.__enter__() # Divide the inputs into regions - if (not np.all(np.isfinite(self.data)) or np.any(self.data < 0)): + if not np.all(np.isfinite(self.data[0])) or np.any(self.data[0] < 0): warnings.warn( "Diffraction patterns contain invalid data. " "All data should be non-negative and finite.", UserWarning) @@ -360,326 +365,229 @@ def __enter__(self): ( self.comm.order, self.batches, - self.parameters.scan, - self.data, - self.parameters.eigen_weights, + self.comm.stripe_start, ) = tike.cluster.by_scan_stripes_contiguous( - self.data, - self.parameters.eigen_weights, - scan=self.parameters.scan, + scan=self.parameters[0].scan, pool=self.comm.pool, shape=(self.comm.pool.num_workers, 1), - dtype=( - tike.precision.floating, - tike.precision.floating - if self.data.itemsize > 2 else self.data.dtype, - tike.precision.floating, - ), - destination=('gpu', 'pinned', 'gpu'), - batch_method=self.parameters.algorithm_options.batch_method, - num_batch=self.parameters.algorithm_options.num_batch, + batch_method=self.parameters[0].algorithm_options.batch_method, + num_batch=self.parameters[0].algorithm_options.num_batch, ) - self.parameters.psi = self.comm.pool.bcast( - [self.parameters.psi.astype(tike.precision.cfloating)]) - - 
self.parameters.probe = self.comm.pool.bcast( - [self.parameters.probe.astype(tike.precision.cfloating)]) - - if self.parameters.probe_options is not None: - self.parameters.probe_options = self.parameters.probe_options.copy_to_device( - self.comm,) - - if self.parameters.object_options is not None: - self.parameters.object_options = self.parameters.object_options.copy_to_device( - self.comm,) - - if self.parameters.exitwave_options is not None: - self.parameters.exitwave_options = self.parameters.exitwave_options.copy_to_device( - self.comm,) - - if self.parameters.eigen_probe is not None: - self.parameters.eigen_probe = self.comm.pool.bcast( - [self.parameters.eigen_probe.astype(tike.precision.cfloating)]) + self.data = self.comm.pool.map( + tike.cluster._split_pinned, + self.comm.order, + x=self.data[0], + dtype=tike.precision.floating + if self.data[0].itemsize > 2 + else self.data[0].dtype, + ) - if self.parameters.position_options is not None: - # TODO: Consider combining put/split, get/join operations? 
- self.parameters.position_options = self.comm.pool.map( - PositionOptions.copy_to_device, - (self.parameters.position_options.split(x) - for x in self.comm.order), - ) + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.split, + self.comm.order, + x=self.parameters[0], + ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype - if self.parameters.probe_options is not None: + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.copy_to_device, + self.parameters, + ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype - if self.parameters.probe_options.init_rescale_from_measurements: - self.parameters.probe = _rescale_probe( + if self.parameters[0].probe_options is not None: + if self.parameters[0].probe_options.init_rescale_from_measurements: + self.parameters = _rescale_probe( self.operator, self.comm, self.data, - self.parameters.exitwave_options, - self.parameters.psi, - self.parameters.scan, - self.parameters.probe, - num_batch=self.parameters.algorithm_options.num_batch, + self.parameters, ) - - if np.isnan(self.parameters.probe_options.probe_photons): - self.parameters.probe_options.probe_photons = np.sum( - np.abs(self.parameters.probe[0].get())**2) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, 
self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype return self def iterate(self, num_iter: int) -> None: """Advance the reconstruction by num_iter epochs.""" start = time.perf_counter() - psi_previous = self.parameters.psi[0].copy() + # psi_previous = self.parameters[0].psi.copy() for i in range(num_iter): if ( - np.sum(self.parameters.algorithm_options.times) - > self.parameters.algorithm_options.time_limit + np.sum(self.parameters[0].algorithm_options.times) + > self.parameters[0].algorithm_options.time_limit ): logger.info("Maximum reconstruction time exceeded.") break - logger.info(f"{self.parameters.algorithm_options.name} epoch " - f"{len(self.parameters.algorithm_options.times):,d}") - - total_epochs = len(self.parameters.algorithm_options.times) - - if self.parameters.probe_options is not None: - self.parameters.probe_options.recover_probe = ( - total_epochs >= self.parameters.probe_options.update_start - and (total_epochs % self.parameters.probe_options.update_period) == 0 - ) # yapf: disable - - if self.parameters.probe_options is not None: - if self.parameters.probe_options.recover_probe: - - if self.parameters.probe_options.median_filter_abs_probe: - self.parameters.probe = self.comm.pool.map( - apply_median_filter_abs_probe, - self.parameters.probe, - med_filt_px = self.parameters.probe_options.median_filter_abs_probe_px - ) - - if self.parameters.probe_options.force_centered_intensity: - self.parameters.probe = self.comm.pool.map( - constrain_center_peak, - self.parameters.probe, - ) - - if self.parameters.probe_options.force_sparsity < 1: - self.parameters.probe = self.comm.pool.map( - constrain_probe_sparsity, - self.parameters.probe, - f=self.parameters.probe_options.force_sparsity, - ) - - if self.parameters.probe_options.force_orthogonality: - ( - self.parameters.probe, - 
power, - ) = (list(a) for a in zip(*self.comm.pool.map( - tike.ptycho.probe.orthogonalize_eig, - self.parameters.probe, - ))) - else: - power = self.comm.pool.map( - tike.ptycho.probe.power, - self.parameters.probe, - ) - - self.parameters.probe_options.power.append( - power[0].get()) + logger.info( + f"{self.parameters[0].algorithm_options.name} epoch " + f"{len(self.parameters[0].algorithm_options.times):,d}" + ) - ( - self.parameters.object_options, - self.parameters.probe_options, - ) = solvers.update_preconditioners( + total_epochs = len(self.parameters[0].algorithm_options.times) + + self.parameters = self.comm.pool.map( + _apply_probe_constraints, + self.parameters, + epoch=total_epochs + ) + + self.parameters = solvers.update_preconditioners( comm=self.comm, + parameters=self.parameters, operator=self.operator, - scan=self.parameters.scan, - probe=self.parameters.probe, - psi=self.parameters.psi, - object_options=self.parameters.object_options, - probe_options=self.parameters.probe_options, ) - self.parameters = getattr( - solvers, - self.parameters.algorithm_options.name, - )( - self.operator, - self.comm, - data=self.data, - batches=self.batches, - parameters=self.parameters, - epoch=len(self.parameters.algorithm_options.times), + self.parameters = self.comm.pool.map( + getattr(solvers, self.parameters[0].algorithm_options.name), + self.parameters, + self.data, + self.batches, + self.comm.pool.streams, + op=self.operator, + epoch=len(self.parameters[0].algorithm_options.times), ) - if self.parameters.object_options.positivity_constraint: - self.parameters.psi = self.comm.pool.map( - tike.ptycho.object.positivity_constraint, - self.parameters.psi, - r=self.parameters.object_options.positivity_constraint, + for i, reduced_probe in enumerate( + self.comm.Allreduce_mean( + [e.probe[None, ...] 
for e in self.parameters], + axis=0, ) + ): + self.parameters[i].probe = reduced_probe - if self.parameters.object_options.smoothness_constraint: - self.parameters.psi = self.comm.pool.map( - tike.ptycho.object.smoothness_constraint, - self.parameters.psi, - a=self.parameters.object_options.smoothness_constraint, - ) + if self.parameters[0].eigen_probe is not None: + for i, reduced_probe in enumerate( + self.comm.Allreduce_mean( + [e.eigen_probe[None, ...] for e in self.parameters], + axis=0, + ) + ): + self.parameters[i].eigen_probe = reduced_probe + + pw = self.parameters[0].probe.shape[-2] + for swapped, parameters in zip( + self.comm.swap_edges( + [e.psi for e in self.parameters], + # reduce overlap to stay away from edge noise + overlap=pw-1, + # The actual edge is centered on the probe + edges=self.comm.stripe_start, + ), + self.parameters, + ): + parameters.psi = swapped - if self.parameters.object_options.clip_magnitude: - self.parameters.psi = self.comm.pool.map( - _clip_magnitude, - self.parameters.psi, - a_max=1.0, + if self.parameters[0].position_options is not None: + # FIXME: Synchronize across nodes + reduced_transform = np.mean( + [e.position_options.transform.asbuffer() for e in self.parameters], + axis=0, ) - - if ( - self.parameters.algorithm_options.rescale_method == 'mean_of_abs_object' - and self.parameters.object_options.preconditioner is not None - and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - ): # yapf: disable - ( - self.parameters.psi, - self.parameters.probe, - ) = (list(a) for a in zip(*self.comm.pool.map( - tike.ptycho.object.remove_object_ambiguity, - self.parameters.psi, - self.parameters.probe, - self.parameters.object_options.preconditioner, - ))) - - elif self.parameters.probe_options is not None: - if ( - self.parameters.probe_options.recover_probe - and self.parameters.algorithm_options.rescale_method == 'constant_probe_photons' - and 
len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - ): # yapf: disable - - self.parameters.probe = self.comm.pool.map( - tike.ptycho.probe - .rescale_probe_using_fixed_intensity_photons, - self.parameters.probe, - Nphotons=self.parameters.probe_options.probe_photons, - probe_power_fraction=None, + for i in range(len(self.parameters)): + self.parameters[ + i + ].position_options.transform = AffineTransform.frombuffer( + reduced_transform ) - if ( - self.parameters.probe_options is not None - and self.parameters.eigen_probe is not None - and self.parameters.probe_options.recover_probe - ): #yapf: disable - ( - self.parameters.eigen_probe, - self.parameters.eigen_weights, - ) = tike.ptycho.probe.constrain_variable_probe( - self.comm, - self.parameters.eigen_probe, - self.parameters.eigen_weights, - ) + self.parameters = self.comm.pool.map( + _apply_object_constraints, + self.parameters, + ) - if self.parameters.position_options: - ( - self.parameters.scan, - self.parameters.position_options, - ) = affine_position_regularization( - self.comm, - updated=self.parameters.scan, - position_options=self.parameters.position_options, - regularization_enabled=self.parameters.position_options[ - 0 - ].use_position_regularization, - ) + self.parameters = self.comm.pool.map( + _apply_position_constraints, + self.parameters, + ) + + reduced_cost = np.mean( + [e.algorithm_options.costs[-1] for e in self.parameters], + ) + for i in range(len(self.parameters)): + self.parameters[i].algorithm_options.costs[-1] = [reduced_cost] - self.parameters.algorithm_options.times.append(time.perf_counter() - - start) + self.parameters[0].algorithm_options.times.append( + time.perf_counter() - start + ) start = time.perf_counter() - update_norm = tike.linalg.mnorm(self.parameters.psi[0] - - psi_previous) + # update_norm = tike.linalg.mnorm(self.parameters.psi[0] - + # psi_previous) - self.parameters.object_options.update_mnorm.append( - 
update_norm.get()) + # self.parameters.object_options.update_mnorm.append( + # update_norm.get()) - logger.info(f"The object update mean-norm is {update_norm:.3e}") + # logger.info(f"The object update mean-norm is {update_norm:.3e}") - if (np.mean(self.parameters.object_options.update_mnorm[-5:]) - < self.parameters.object_options.convergence_tolerance): - logger.info( - f"The object seems converged. {update_norm:.3e} < " - f"{self.parameters.object_options.convergence_tolerance:.3e}" - ) - break + # if (np.mean(self.parameters.object_options.update_mnorm[-5:]) + # < self.parameters.object_options.convergence_tolerance): + # logger.info( + # f"The object seems converged. {update_norm:.3e} < " + # f"{self.parameters.object_options.convergence_tolerance:.3e}" + # ) + # break logger.info( - '%10s cost is %+1.3e', - self.parameters.exitwave_options.noise_model, - np.mean(self.parameters.algorithm_options.costs[-1]), + "%10s cost is %+1.3e", + self.parameters[0].exitwave_options.noise_model, + np.mean(self.parameters[0].algorithm_options.costs[-1]), ) - def get_scan(self): + def get_scan(self) -> npt.NDArray: reorder = np.argsort(np.concatenate(self.comm.order)) - return self.comm.pool.gather_host( - self.parameters.scan, - axis=-2, + return np.concatenate( + [cp.asnumpy(e.scan) for e in self.parameters], + axis=0, )[reorder] - def get_result(self): + def get_result(self) -> solvers.PtychoParameters: """Return the current parameter estimates.""" reorder = np.argsort(np.concatenate(self.comm.order)) - parameters = solvers.PtychoParameters( - probe=self.parameters.probe[0].get(), - psi=self.parameters.psi[0].get(), - scan=self.comm.pool.gather_host( - self.parameters.scan, - axis=-2, - )[reorder], - algorithm_options=self.parameters.algorithm_options, - ) - - if self.parameters.eigen_probe is not None: - parameters.eigen_probe = self.parameters.eigen_probe[0].get() - - if self.parameters.eigen_weights is not None: - parameters.eigen_weights = self.comm.pool.gather( - 
self.parameters.eigen_weights, - axis=-3, - )[reorder].get() - - if self.parameters.probe_options is not None: - parameters.probe_options = self.parameters.probe_options.copy_to_host( - ) - if self.parameters.object_options is not None: - parameters.object_options = self.parameters.object_options.copy_to_host( - ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) - if self.parameters.exitwave_options is not None: - parameters.exitwave_options = self.parameters.exitwave_options.copy_to_host( + # Use plain map here instead of threaded map so this method can be + # called when the context is closed. + parameters = list( + map( + solvers.PtychoParameters.copy_to_host, + self.parameters, ) + ) - if self.parameters.position_options is not None: - host_position_options = self.parameters.position_options[0].empty() - for x, o in zip( - self.comm.pool.map( - PositionOptions.copy_to_host, - self.parameters.position_options, - ), - self.comm.order, - ): - host_position_options = host_position_options.join(x, o) - parameters.position_options = host_position_options + parameters = solvers.PtychoParameters.join( + parameters, + reorder, + stripe_start=self.comm.stripe_start, + ) return parameters def __exit__(self, type, value, traceback): - self.parameters = self.get_result() + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.copy_to_host, + self.parameters, + ) self.comm.__exit__(type, value, traceback) self.operator.__exit__(type, value, traceback) self.device.__exit__(type, value, traceback) @@ -693,41 +601,35 @@ def get_convergence( ) -> typing.Tuple[typing.List[typing.List[float]], typing.List[float]]: """Return the cost function values and times as a tuple.""" return ( - self.parameters.algorithm_options.costs, - self.parameters.algorithm_options.times, + self.parameters[0].algorithm_options.costs, + self.parameters[0].algorithm_options.times, ) def get_psi(self) -> np.array: """Return 
the current object estimate as a numpy array.""" - return self.parameters.psi[0].get() + return ObjectOptions.join_psi( + [cp.asnumpy(e.psi) for e in self.parameters], + probe_width=self.parameters[0].probe.shape[-2], + stripe_start=self.comm.stripe_start, + ) def get_probe(self) -> typing.Tuple[np.array, np.array, np.array]: """Return the current probe, eigen_probe, weights as numpy arrays.""" - reorder = np.argsort(np.concatenate(self.comm.order)) - if self.parameters.eigen_probe is None: + if self.parameters[0].eigen_probe is None: eigen_probe = None else: - eigen_probe = self.parameters.eigen_probe[0].get() + eigen_probe = self.parameters[0].eigen_probe.get() if self.parameters.eigen_weights is None: eigen_weights = None else: + reorder = np.argsort(np.concatenate(self.comm.order)) eigen_weights = self.comm.pool.gather( self.parameters.eigen_weights, axis=-3, )[reorder].get() - probe = self.parameters.probe[0].get() + probe = self.parameters[0].probe.get() return probe, eigen_probe, eigen_weights - def peek(self) -> typing.Tuple[np.array, np.array, np.array, np.array]: - """Return the curent values of object and probe as numpy arrays. - - Parameters returned in a tuple of object, probe, eigen_probe, - eigen_weights. 
- """ - psi = self.get_psi() - probe, eigen_probe, eigen_weights = self.get_probe() - return psi, probe, eigen_probe, eigen_weights - def append_new_data( self, new_data: npt.NDArray, @@ -805,18 +707,139 @@ def append_new_data( new_scan, ) +def _apply_probe_constraints( + parameters: solvers.PtychoParameters, + *, + epoch: int, +) -> solvers.PtychoParameters: + if parameters.probe_options is not None: + if parameters.probe_options.recover_probe(epoch): + + if parameters.probe_options.median_filter_abs_probe: + parameters.probe = apply_median_filter_abs_probe( + parameters.probe, + med_filt_px=parameters.probe_options.median_filter_abs_probe_px, + ) + + if parameters.probe_options.force_centered_intensity: + parameters.probe = constrain_center_peak( + parameters.probe, + ) + + if parameters.probe_options.force_sparsity < 1: + parameters.probe = constrain_probe_sparsity( + parameters.probe, + f=parameters.probe_options.force_sparsity, + ) + + if parameters.probe_options.force_orthogonality: + ( + parameters.probe, + power, + ) = orthogonalize_eig( + parameters.probe, + ) + else: + power = tike.ptycho.probe.power( + parameters.probe, + ) + + parameters.probe_options.power.append(cp.asnumpy(power)) + + if parameters.algorithm_options.rescale_method == "constant_probe_photons" and ( + len(parameters.algorithm_options.costs) + % parameters.algorithm_options.rescale_period + == 0 + ): + parameters.probe = ( + tike.ptycho.probe.rescale_probe_using_fixed_intensity_photons( + parameters.probe, + Nphotons=parameters.probe_options.probe_photons, + probe_power_fraction=None, + ) + ) + + if ( + parameters.eigen_probe is not None + and parameters.probe_options.recover_probe(epoch) + ): + ( + parameters.eigen_probe, + parameters.eigen_weights, + ) = tike.ptycho.probe.constrain_variable_probe( + parameters.eigen_probe, + parameters.eigen_weights, + ) + + return parameters + + +def _apply_object_constraints( + parameters: solvers.PtychoParameters, +) -> solvers.PtychoParameters: + 
if parameters.object_options.positivity_constraint: + parameters.psi = tike.ptycho.object.positivity_constraint( + parameters.psi, + r=parameters.object_options.positivity_constraint, + ) + + if parameters.object_options.smoothness_constraint: + parameters.psi = tike.ptycho.object.smoothness_constraint( + parameters.psi, + a=parameters.object_options.smoothness_constraint, + ) + + if parameters.object_options.clip_magnitude: + parameters.psi = _clip_magnitude( + parameters.psi, + a_max=1.0, + ) + + if ( + parameters.algorithm_options.name != "dm" + and parameters.algorithm_options.rescale_method == "mean_of_abs_object" + and parameters.object_options.preconditioner is not None + and ( + len(parameters.algorithm_options.costs) + % parameters.algorithm_options.rescale_period + == 0 + ) + ): + ( + parameters.psi, + parameters.probe, + ) = tike.ptycho.object.remove_object_ambiguity( + parameters.psi, + parameters.probe, + parameters.object_options.preconditioner, + ) + + return parameters + + +def _apply_position_constraints( + parameters: solvers.PtychoParameters, +) -> solvers.PtychoParameters: + if parameters.position_options: + ( + parameters.scan, + parameters.position_options, + ) = affine_position_regularization( + updated=parameters.scan, + position_options=parameters.position_options, + ) + + return parameters + def _order_join(a, b): return np.append(a, b + len(a)) def _get_rescale( - data, - measured_pixels, - psi, - scan, - probe, - streams, + data: npt.ArrayLike, + parameters: solvers.PtychoParameters, + streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, ): @@ -832,17 +855,21 @@ def make_certain_args_constant( ( data, ) = ind_args - nonlocal sums, scan + nonlocal sums intensity, _ = operator._compute_intensity( None, - psi, - scan[lo:hi], - probe, + parameters.psi, + parameters.scan[lo:hi], + parameters.probe, ) - sums[0] += cp.sum(data[:, measured_pixels], dtype=np.double) - sums[1] += cp.sum(intensity[:, measured_pixels], 
dtype=np.double) + sums[0] += cp.sum( + data[:, parameters.exitwave_options.measured_pixels], dtype=np.double + ) + sums[1] += cp.sum( + intensity[:, parameters.exitwave_options.measured_pixels], dtype=np.double + ) tike.communicators.stream.stream_and_modify2( f=make_certain_args_constant, @@ -857,8 +884,12 @@ def make_certain_args_constant( return sums -def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, - num_batch): +def _rescale_probe( + operator: tike.operators.Ptycho, + comm: tike.communicators.Comm, + data: typing.List[npt.ArrayLike], + parameters: typing.List[solvers.PtychoParameters], +): """Rescale probe so model and measured intensity are similar magnitude. Rescales the probe so that the sum of modeled intensity at the detector is @@ -868,11 +899,8 @@ def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, n = comm.pool.map( _get_rescale, data, - exitwave_options.measured_pixels, - psi, - scan, - probe, - comm.streams, + parameters, + comm.pool.streams, operator=operator, ) except cp.cuda.memory.OutOfMemoryError: @@ -883,13 +911,31 @@ def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, n = np.sqrt(comm.Allreduce_reduce_cpu(n)) - rescale = cp.asarray(n[0] / n[1]) + # Force precision to prevent type promotion downstream + rescale = cp.asarray(n[0] / n[1], dtype=tike.precision.floating) logger.info("Probe rescaled by %f", rescale) - probe[0] *= rescale + rescale = comm.pool.bcast([rescale]) + return comm.pool.map( + _rescale_probe_helper, + parameters, + rescale, + ) + + +def _rescale_probe_helper( + parameters: solvers.PtychoParameters, + rescale: float, +) -> solvers.PtychoParameters: + parameters.probe = parameters.probe * rescale + + if np.isnan(parameters.probe_options.probe_photons): + parameters.probe_options.probe_photons = cp.sum( + cp.square(cp.abs(parameters.probe)) + ).get() - return comm.pool.bcast([probe[0]]) + return parameters def reconstruct_multigrid( @@ -938,29 
+984,30 @@ def reconstruct_multigrid( use_mpi=use_mpi, ) as context: context.iterate(resampled_parameters.algorithm_options.num_iter) + result = context.get_result() if level == 0: if ( logger.getEffectiveLevel() <= logging.INFO - ) and context.parameters.position_options: + ) and result.position_options: mean_scaling = 0.5 * ( - context.parameters.position_options.transform.scale0 - + context.parameters.position_options.transform.scale1 + result.position_options.transform.scale0 + + result.position_options.transform.scale1 ) logger.info( f"Global scaling of {mean_scaling:.3e} detected from position correction." " Probably your estimate of photon energy and/or sample to detector " "distance is off by that amount." ) - t = context.parameters.position_options.transform.asarray() + t = result.position_options.transform.asarray() logger.info(f"""Affine transform parameters: {t[0,0]: .3e}, {t[0,1]: .3e} {t[1,0]: .3e}, {t[1,1]: .3e} """) - return context.parameters + return result # Upsample result to next grid - resampled_parameters = context.parameters.resample(2.0, interp) + resampled_parameters = result.resample(2.0, interp) raise RuntimeError('This should not happen.') diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index 6234be70..64857f79 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -7,12 +7,27 @@ import tike.operators import tike.precision -from .options import ObjectOptions, ProbeOptions +from .options import ObjectOptions, ProbeOptions, PtychoParameters -@cp.fuse() -def _rolling_average(old, new): - return 0.5 * (new + old) +def _rolling_average_object(parameters: PtychoParameters, new): + if parameters.object_options.preconditioner is None: + parameters.object_options.preconditioner = new + else: + parameters.object_options.preconditioner = 0.5 * ( + new + parameters.object_options.preconditioner + ) + return parameters + + +def 
_rolling_average_probe(parameters: PtychoParameters, new): + if parameters.probe_options.preconditioner is None: + parameters.probe_options.preconditioner = new + else: + parameters.probe_options.preconditioner = 0.5 * ( + new + parameters.probe_options.preconditioner + ) + return parameters @cp.fuse() @@ -24,17 +39,15 @@ def _probe_amp_sum(probe): def _psi_preconditioner( - psi: npt.NDArray[tike.precision.cfloating], - scan: npt.NDArray[tike.precision.floating], - probe: npt.NDArray[tike.precision.cfloating], + parameters: PtychoParameters, streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, ) -> npt.NDArray: psi_update_denominator = cp.zeros( - shape=psi.shape, - dtype=psi.dtype, + shape=parameters.psi.shape, + dtype=parameters.psi.dtype, ) def make_certain_args_constant( @@ -44,26 +57,26 @@ def make_certain_args_constant( ) -> None: nonlocal psi_update_denominator - probe_amp = _probe_amp_sum(probe[:, 0]) + probe_amp = _probe_amp_sum(parameters.probe)[:, 0] psi_update_denominator[0] = operator.diffraction.patch.adj( patches=probe_amp, images=psi_update_denominator[0], - positions=scan[lo:hi], + positions=parameters.scan[lo:hi], ) - probe1 = probe[:, 0] - for i in range(1, len(psi)): + probe1 = parameters.probe[:, 0] + for i in range(1, len(parameters.psi)): probe1 = operator.diffraction.diffraction.fwd( probe=probe1, - scan=scan[lo:hi], - psi=psi[i-1], + scan=parameters.scan[lo:hi], + psi=parameters.psi[i-1], ) probe1 = operator.diffraction.propagation.fwd(probe1) probe_amp = _probe_amp_sum(probe1) psi_update_denominator[i] = operator.diffraction.patch.adj( patches=probe_amp, images=psi_update_denominator[i], - positions=scan[lo:hi], + positions=parameters.scan[lo:hi], ) tike.communicators.stream.stream_and_modify2( @@ -71,7 +84,7 @@ def make_certain_args_constant( ind_args=[], streams=streams, lo=0, - hi=len(scan), + hi=len(parameters.scan), ) return psi_update_denominator @@ -87,17 +100,15 @@ def _patch_amp_sum(patches): def 
_probe_preconditioner( - psi: npt.NDArray[tike.precision.cfloating], - scan: npt.NDArray[tike.precision.floating], - probe: npt.NDArray[tike.precision.cfloating], + parameters: PtychoParameters, streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, ) -> npt.NDArray: probe_update_denominator = cp.zeros( - shape=probe.shape[-2:], - dtype=probe.dtype, + shape=parameters.probe.shape[-2:], + dtype=parameters.probe.dtype, ) def make_certain_args_constant( @@ -109,9 +120,9 @@ def make_certain_args_constant( # FIXME: Only use the first slice for the probe preconditioner patches = operator.diffraction.patch.fwd( - images=psi[0], - positions=scan[lo:hi], - patch_width=probe.shape[-1], + images=parameters.psi[0], + positions=parameters.scan[lo:hi], + patch_width=parameters.probe.shape[-1], ) probe_update_denominator[...] += _patch_amp_sum(patches) assert probe_update_denominator.ndim == 2 @@ -121,7 +132,7 @@ def make_certain_args_constant( ind_args=[], streams=streams, lo=0, - hi=len(scan), + hi=len(parameters.scan), ) return probe_update_denominator @@ -129,56 +140,41 @@ def make_certain_args_constant( def update_preconditioners( comm: tike.communicators.Comm, + parameters: typing.List[PtychoParameters], operator: tike.operators.Ptycho, - scan, - probe, - psi, - object_options: typing.Optional[ObjectOptions] = None, - probe_options: typing.Optional[ProbeOptions] = None, -) -> typing.Tuple[ObjectOptions, ProbeOptions]: +) -> typing.List[PtychoParameters]: """Update the probe and object preconditioners.""" - if object_options: + if parameters[0].object_options: preconditioner = comm.pool.map( _psi_preconditioner, - psi, - scan, - probe, - comm.streams, + parameters, + comm.pool.streams, operator=operator, ) - preconditioner = comm.Allreduce(preconditioner) + # preconditioner = comm.Allreduce(preconditioner) - if object_options.preconditioner is None: - object_options.preconditioner = preconditioner - else: - object_options.preconditioner = comm.pool.map( - 
_rolling_average, - object_options.preconditioner, - preconditioner, - ) - - if probe_options: + parameters = comm.pool.map( + _rolling_average_object, + parameters, + preconditioner, + ) + if parameters[0].probe_options: preconditioner = comm.pool.map( _probe_preconditioner, - psi, - scan, - probe, - comm.streams, + parameters, + comm.pool.streams, operator=operator, ) - preconditioner = comm.Allreduce(preconditioner) + # preconditioner = comm.Allreduce(preconditioner) - if probe_options.preconditioner is None: - probe_options.preconditioner = preconditioner - else: - probe_options.preconditioner = comm.pool.map( - _rolling_average, - probe_options.preconditioner, - preconditioner, - ) + parameters = comm.pool.map( + _rolling_average_probe, + parameters, + preconditioner, + ) - return object_options, probe_options + return parameters diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index ece6f3fd..3700d123 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -23,14 +23,14 @@ def lstsq_grad( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], + parameters: PtychoParameters, + data: npt.NDArray, batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], *, - parameters: PtychoParameters, + op: tike.operators.Ptycho, epoch: int, -): +) -> PtychoParameters: """Solve the ptychography problem using Odstrcil et al's approach. Object and probe are updated simultaneously using optimal step sizes @@ -69,55 +69,35 @@ def lstsq_grad( .. 
seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe scan = parameters.scan psi = parameters.psi - + probe = parameters.probe algorithm_options = parameters.algorithm_options - - probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe - + eigen_weights = parameters.eigen_weights + eigen_probe = parameters.eigen_probe + measured_pixels = parameters.exitwave_options.measured_pixels + exitwave_options = parameters.exitwave_options position_options = parameters.position_options object_options = parameters.object_options - exitwave_options = parameters.exitwave_options - - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers - else: - beigen_probe = eigen_probe - - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers - else: - beigen_weights = eigen_weights - - if object_options is not None: - if algorithm_options.batch_method == 'compact': - object_options.combined_update = cp.zeros_like(psi[0]) + probe_options = parameters.probe_options + recover_probe = probe_options is not None and epoch >= probe_options.update_start - if recover_probe: - probe_options.probe_update_sum = cp.zeros_like(probe[0]) + # CONVERSTION AREA ABOVE --------------------------------------- if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - batch_cost = [] - beta_object = [] - beta_probe = [] - for batch_index in order(algorithm_options.num_batch): + object_combined_update = cp.zeros_like(psi) + probe_combined_update = cp.zeros_like(probe) + position_update_numerator = None + position_update_denominator = None + batch_cost = cp.empty(algorithm_options.num_batch, 
dtype=tike.precision.floating) + beta_object: typing.List[float] = [] + beta_probe: typing.List[float] = [] + for batch_index in order(algorithm_options.num_batch): ( diff, unique_probe, @@ -129,21 +109,19 @@ def lstsq_grad( position_update_numerator, position_update_denominator, position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, + ) = _get_nearplane_gradients( data, psi, scan, probe, - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, batches, position_update_numerator, position_update_denominator, - [None] * comm.pool.num_workers if position_options is - None else position_options, - comm.streams, - exitwave_options.measured_pixels, + position_options, + streams, + measured_pixels, object_options.preconditioner, batch_index=batch_index, num_batch=algorithm_options.num_batch, @@ -152,31 +130,24 @@ def lstsq_grad( recover_psi=object_options is not None, recover_probe=recover_probe, recover_positions=position_options is not None, - ))) - position_options = None if position_options[ - 0] is None else position_options - - if object_options is not None: - object_upd_sum = comm.Allreduce(object_upd_sum) + ) if recover_probe: - m_probe_update = comm.pool.bcast( - [comm.Allreduce_mean( - m_probe_update, - axis=-5, - )]) + m_probe_update = cp.mean( + m_probe_update, + axis=-5, + ) ( - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, ) = _update_nearplane( - comm, diff, probe_update, m_probe_update, probe, - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, patches, batches, batch_index=batch_index, @@ -190,8 +161,7 @@ def lstsq_grad( A4, b1, b2, - ) = (list(a) for a in zip(*comm.pool.map( - _precondition_nearplane_gradients, + ) = _precondition_nearplane_gradients( diff, scan, unique_probe, @@ -207,23 +177,22 @@ def lstsq_grad( recover_psi=object_options is not None, recover_probe=recover_probe, probe_options=probe_options, - ))) + ) if object_options is not None: - A1_delta = 
comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)]) + A1_delta = cp.mean(A1, axis=-3) else: - A1_delta = [None] * comm.pool.num_workers + A1_delta = None if recover_probe: - A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)]) + A4_delta = cp.mean(A4, axis=-3) else: - A4_delta = [None] * comm.pool.num_workers + A4_delta = None ( weighted_step_psi, weighted_step_probe, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_steps, + ) = _get_nearplane_steps( A1, A2, A4, @@ -234,16 +203,16 @@ def lstsq_grad( recover_psi=object_options is not None, recover_probe=recover_probe, m=0, - ))) + ) if object_options is not None: - bbeta_object = comm.Allreduce_mean( + bbeta_object = cp.mean( weighted_step_psi, axis=-5, )[..., 0, 0, 0] if recover_probe: - bbeta_probe = comm.Allreduce_mean( + bbeta_probe = cp.mean( weighted_step_probe, axis=-5, ) @@ -252,7 +221,7 @@ def lstsq_grad( if object_options is not None: if algorithm_options.batch_method != 'compact': # (27b) Object update - dpsi = bbeta_object[0] * object_update_precond[0] + dpsi = bbeta_object * object_update_precond if object_options.use_adaptive_moment: ( @@ -266,55 +235,47 @@ def lstsq_grad( vdecay=object_options.vdecay, mdecay=object_options.mdecay, ) - psi[0] = psi[0] + dpsi - psi = comm.pool.bcast([psi[0]]) + psi = psi + dpsi else: - object_options.combined_update += object_upd_sum[0] - - if recover_probe: - dprobe = bbeta_probe[0] * m_probe_update[0] - probe_options.probe_update_sum += dprobe / algorithm_options.num_batch - # (27a) Probe update - probe[0] += dprobe - probe = comm.pool.bcast([probe[0]]) + object_combined_update += object_upd_sum - for c in costs: - batch_cost = batch_cost + c.tolist() - - if object_options is not None: beta_object.append(bbeta_object) if recover_probe: + dprobe = bbeta_probe * m_probe_update + probe_combined_update += dprobe / algorithm_options.num_batch + # (27a) Probe update + probe += dprobe + beta_probe.append(bbeta_probe) - if eigen_probe is not None: - 
eigen_probe = beigen_probe + batch_cost[batch_index] = cp.mean(costs) - if eigen_weights is not None: - eigen_weights = beigen_weights - - if position_options: - scan, position_options = zip(*comm.pool.map( - _update_position, + if ( + position_options is not None + and position_update_numerator is not None + and position_update_denominator is not None + ): + scan, position_options = _update_position( scan, position_options, position_update_numerator, position_update_denominator, epoch=epoch, - )) + ) - algorithm_options.costs.append(batch_cost) + algorithm_options.costs.append(batch_cost.tolist()) if object_options and algorithm_options.batch_method == 'compact': object_update_precond = _precondition_object_update( - object_options.combined_update, - object_options.preconditioner[0], + object_combined_update, + object_options.preconditioner, ) # (27b) Object update beta_object = cp.mean(cp.stack(beta_object)) dpsi = beta_object * object_update_precond - psi[0] = psi[0] + dpsi + psi = psi + dpsi if object_options.use_adaptive_moment: ( @@ -326,27 +287,25 @@ def lstsq_grad( v=object_options.v, m=object_options.m, mdecay=object_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_object, memory_length=3, ) - weight = object_options.preconditioner[0] + weight = object_options.preconditioner weight = weight / (0.1 * weight.max() + weight) - psi[0] = psi[0] + weight * dpsi - - psi = comm.pool.bcast([psi[0]]) + psi = psi + weight * dpsi if recover_probe: if probe_options.use_adaptive_moment: beta_probe = cp.mean(cp.stack(beta_probe)) - dprobe = probe_options.probe_update_sum + dprobe = probe_combined_update if probe_options.v is None: - probe_options.v = np.zeros_like( + probe_options.v = cp.zeros_like( dprobe, shape=(3, *dprobe.shape), ) if probe_options.m is None: - probe_options.m = np.zeros_like(dprobe,) + probe_options.m = cp.zeros_like(dprobe,) # 
ptychoshelves only applies momentum to the main probe mode = 0 ( @@ -358,44 +317,44 @@ def lstsq_grad( v=probe_options.v[..., mode, :, :], m=probe_options.m[..., mode, :, :], mdecay=probe_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_probe, memory_length=3, ) - probe[0][..., mode, :, :] = probe[0][..., mode, :, :] + d - probe = comm.pool.bcast([probe[0]]) + probe[..., mode, :, :] = probe[..., mode, :, :] + d + + # CONVERSION AREA BELOW ---------------------- - parameters.probe = probe - parameters.psi = psi parameters.scan = scan + parameters.psi = psi + parameters.probe = probe parameters.algorithm_options = algorithm_options - parameters.probe_options = probe_options - parameters.object_options = object_options - parameters.position_options = position_options parameters.eigen_weights = eigen_weights parameters.eigen_probe = eigen_probe + parameters.exitwave_options = exitwave_options + parameters.position_options = position_options + parameters.object_options = object_options + parameters.probe_options = probe_options + return parameters def _update_nearplane( - comm: tike.communicators.Comm, - diff, - probe_update, - m_probe_update, - probe: typing.List[npt.NDArray[cp.csingle]], - eigen_probe: typing.List[npt.NDArray[cp.csingle]], - eigen_weights: typing.List[npt.NDArray[cp.single]], - patches, + diff: npt.NDArray[cp.csingle], + probe_update: npt.NDArray[cp.csingle], + m_probe_update: npt.NDArray[cp.csingle], + probe: npt.NDArray[cp.csingle], + eigen_probe: npt.NDArray[cp.csingle], + eigen_weights: npt.NDArray[cp.single], + patches: npt.NDArray[cp.csingle], batches, *, batch_index: int, num_batch: int, ): m = 0 - if eigen_weights[0] is not None: - - eigen_weights = comm.pool.map( - _get_coefs_intensity, + if eigen_weights is not None: + eigen_weights = _get_coefs_intensity( eigen_weights, diff, probe, @@ -406,23 +365,20 @@ def _update_nearplane( ) 
# (30) residual probe updates - if eigen_weights[0].shape[-2] > 1: - R = comm.pool.map( - _get_residuals, + if eigen_weights.shape[-2] > 1: + R = _get_residuals( probe_update, m_probe_update, m=m, ) - if eigen_probe[0] is not None and m < eigen_probe[0].shape[-3]: - assert eigen_weights[0].shape[-2] == eigen_probe[0].shape[-4] + 1 - for eigen_index in range(1, eigen_probe[0].shape[-4] + 1): - + if eigen_probe is not None and m < eigen_probe.shape[-3]: + assert eigen_weights.shape[-2] == eigen_probe.shape[-4] + 1 + for eigen_index in range(1, eigen_probe.shape[-4] + 1): ( eigen_probe, eigen_weights, ) = tike.ptycho.probe.update_eigen_probe( - comm, R, eigen_probe, eigen_weights, @@ -435,10 +391,9 @@ def _update_nearplane( m=m, ) - if eigen_index + 1 < eigen_weights[0].shape[-2]: + if eigen_index + 1 < eigen_weights.shape[-2]: # Subtract projection of R onto new probe from R - R = comm.pool.map( - _update_residuals, + R = _update_residuals( R, eigen_probe, batches, diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index e23c92b9..18d6e0be 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -2,11 +2,14 @@ import abc import dataclasses import typing +import copy import numpy as np import numpy.typing as npt import scipy.ndimage +import cupy as cp +import tike.precision from tike.ptycho.object import ObjectOptions from tike.ptycho.position import PositionOptions, check_allowed_positions from tike.ptycho.probe import ProbeOptions @@ -123,7 +126,7 @@ class PtychoParameters(): default_factory=RpieOptions,) """A class containing algorithm specific parameters""" - exitwave_options: typing.Union[ExitWaveOptions, None] = None + exitwave_options: ExitWaveOptions = None """A class containing settings related to exitwave updates.""" probe_options: typing.Union[ProbeOptions, None] = None @@ -191,6 +194,141 @@ def resample( if self.exitwave_options is not None else None, ) + def copy_to_device(self) -> 
PtychoParameters: + """Copy to the current device.""" + return PtychoParameters( + probe=cp.asarray( + self.probe, + dtype=tike.precision.cfloating, + ), + psi=cp.asarray( + self.psi, + dtype=tike.precision.cfloating, + ), + scan=cp.asarray( + self.scan, + dtype=tike.precision.floating, + ), + eigen_probe=cp.asarray( + self.eigen_probe, + dtype=tike.precision.cfloating, + ) + if self.eigen_probe is not None + else None, + eigen_weights=cp.asarray( + self.eigen_weights, + dtype=tike.precision.floating, + ) + if self.eigen_weights is not None + else None, + algorithm_options=self.algorithm_options, + exitwave_options=self.exitwave_options.copy_to_device() + if self.exitwave_options is not None + else None, + probe_options=self.probe_options.copy_to_device() + if self.probe_options is not None + else None, + object_options=self.object_options.copy_to_device() + if self.object_options is not None + else None, + position_options=self.position_options.copy_to_device() + if self.position_options is not None + else None, + ) + + def copy_to_host(self) -> PtychoParameters: + """Copy to the host.""" + return PtychoParameters( + probe=cp.asnumpy(self.probe), + psi=cp.asnumpy(self.psi), + scan=cp.asnumpy(self.scan), + eigen_probe=cp.asnumpy(self.eigen_probe) + if self.eigen_probe is not None + else None, + eigen_weights=cp.asnumpy(self.eigen_weights) + if self.eigen_weights is not None + else None, + algorithm_options=self.algorithm_options, + exitwave_options=self.exitwave_options.copy_to_host() + if self.exitwave_options is not None + else None, + probe_options=self.probe_options.copy_to_host() + if self.probe_options is not None + else None, + object_options=self.object_options.copy_to_host() + if self.object_options is not None + else None, + position_options=self.position_options.copy_to_host() + if self.position_options is not None + else None, + ) + + @staticmethod + def split( + indices: npt.NDArray[np.intc], + *, + x: PtychoParameters, + ) -> PtychoParameters: + 
"""Return a new PtychoParameters with only the data from the indices""" + return PtychoParameters( + probe=x.probe.astype(tike.precision.cfloating), + psi=x.psi.astype(tike.precision.cfloating), + scan=x.scan[indices].astype(tike.precision.floating), + eigen_probe=x.eigen_probe.astype(tike.precision.cfloating) + if x.eigen_probe is not None + else None, + eigen_weights=x.eigen_weights[indices].astype(tike.precision.floating) + if x.eigen_weights is not None + else None, + algorithm_options=copy.deepcopy(x.algorithm_options), + exitwave_options=x.exitwave_options, + probe_options=x.probe_options, + object_options=x.object_options, + position_options=x.position_options.split(indices) + if x.position_options is not None + else None, + ) + + @staticmethod + def join( + x: typing.Iterable[PtychoParameters], + reorder: npt.NDArray[np.intc], + stripe_start: typing.List[int], + ) -> PtychoParameters: + return PtychoParameters( + probe=x[0].probe, + psi=ObjectOptions.join_psi( + [e.psi for e in x], + probe_width=x[0].probe.shape[-2], + stripe_start=stripe_start, + ), + scan=np.concatenate( + [e.scan for e in x], + axis=0, + )[reorder], + eigen_probe=x[0].eigen_probe, + eigen_weights=np.concatenate( + [e.eigen_weights for e in x], + axis=0, + )[reorder] + if x[0].eigen_weights is not None + else None, + # TODO: costs and times should be joined somehow? 
+ algorithm_options=x[0].algorithm_options, + exitwave_options=x[0].exitwave_options, + # TODO: synchronize probe momentum elsewhere + probe_options=x[0].probe_options, + object_options=ObjectOptions.join( + [e.object_options for e in x], + stripe_start=stripe_start, + probe_width=x[0].probe.shape[-2], + ), + position_options=PositionOptions.join( + [e.position_options for e in x], + reorder, + ), + ) + def _resize_spline(x: np.ndarray, f: float) -> np.ndarray: return scipy.ndimage.zoom( diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 220cf1a7..878bad82 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -1,7 +1,9 @@ import logging +import typing import cupy as cp import cupyx.scipy.stats +import numpy as np import numpy.typing as npt import tike.communicators @@ -22,12 +24,12 @@ def rpie( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], - batches: typing.List[typing.List[npt.NDArray[cp.intc]]], - *, parameters: PtychoParameters, + data: npt.NDArray, + batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], + *, + op: tike.operators.Ptycho, epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using regularized ptychographical engine. @@ -74,67 +76,54 @@ def rpie( .. 
seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe scan = parameters.scan psi = parameters.psi + probe = parameters.probe algorithm_options = parameters.algorithm_options + eigen_weights = parameters.eigen_weights + eigen_probe = parameters.eigen_probe + measured_pixels = parameters.exitwave_options.measured_pixels exitwave_options = parameters.exitwave_options - probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe - position_options = parameters.position_options object_options = parameters.object_options - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers - else: - beigen_probe = eigen_probe + probe_options = parameters.probe_options + recover_probe = probe_options is not None and epoch >= probe_options.update_start - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers - else: - beigen_weights = eigen_weights + # CONVERSTION AREA ABOVE --------------------------------------- if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers + psi_update_numerator = None + probe_update_numerator = None + position_update_numerator = None + position_update_denominator = None batch_cost: typing.List[float] = [] for n in order(algorithm_options.num_batch): - ( cost, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_weights, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, + eigen_weights, + ) = _get_nearplane_gradients( data, scan, psi, probe, - 
exitwave_options.measured_pixels, + measured_pixels, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, batches, - comm.streams, + streams, n=n, op=op, object_options=object_options, @@ -142,16 +131,15 @@ def rpie( recover_probe=recover_probe, position_options=position_options, exitwave_options=exitwave_options, - ))) + ) - batch_cost.append(comm.Allreduce_mean(cost, axis=None).get()) + batch_cost.append(cost) if algorithm_options.batch_method != 'compact': ( psi, probe, ) = _update( - comm, psi, probe, psi_update_numerator, @@ -161,8 +149,8 @@ def rpie( recover_probe, algorithm_options, ) - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers + psi_update_numerator = None + probe_update_numerator = None algorithm_options.costs.append(batch_cost) @@ -170,8 +158,7 @@ def rpie( ( scan, position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _update_position, + ) = _update_position( scan, position_options, position_update_numerator, @@ -179,14 +166,13 @@ def rpie( max_shift=probe[0].shape[-1] * 0.1, alpha=algorithm_options.alpha, epoch=epoch, - ))) + ) if algorithm_options.batch_method == 'compact': ( psi, probe, ) = _update( - comm, psi, probe, psi_update_numerator, @@ -195,23 +181,27 @@ def rpie( probe_options, recover_probe, algorithm_options, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=[float(np.mean(x)) for x in algorithm_options.costs[-3:]], ) if eigen_weights is not None: - eigen_weights = comm.pool.map( - _normalize_eigen_weights, - beigen_weights, + eigen_weights = _normalize_eigen_weights( + eigen_weights, ) - parameters.probe = probe - parameters.psi = psi + # CONVERSION AREA BELOW ---------------------- + parameters.scan = scan + parameters.psi = psi + parameters.probe = probe parameters.algorithm_options = algorithm_options - 
parameters.probe_options = probe_options - parameters.object_options = object_options - parameters.position_options = position_options parameters.eigen_weights = eigen_weights + parameters.eigen_probe = eigen_probe + parameters.exitwave_options = exitwave_options + parameters.position_options = position_options + parameters.object_options = object_options + parameters.probe_options = probe_options + return parameters @@ -224,7 +214,6 @@ def _normalize_eigen_weights(eigen_weights): def _update( - comm: tike.communicators.Comm, psi: npt.NDArray[cp.csingle], probe: npt.NDArray[cp.csingle], psi_update_numerator: npt.NDArray[cp.csingle], @@ -233,19 +222,19 @@ def _update( probe_options: ProbeOptions, recover_probe: bool, algorithm_options: RpieOptions, - errors: typing.Union[None, typing.List[float]] = None, -): + errors: typing.Union[None, npt.NDArray] = None, +) -> typing.Tuple[npt.NDArray[cp.csingle], npt.NDArray[cp.csingle]]: if object_options: - psi_update_numerator = comm.Allreduce_reduce_gpu( - psi_update_numerator)[0] dpsi = psi_update_numerator deno = ( - (1 - algorithm_options.alpha) * object_options.preconditioner[0] + - algorithm_options.alpha * object_options.preconditioner[0].max( + (1 - algorithm_options.alpha) * object_options.preconditioner + + algorithm_options.alpha + * object_options.preconditioner.max( axis=(-2, -1), keepdims=True, - )) - psi[0] = psi[0] + dpsi / deno + ) + ) + psi = psi + dpsi / deno if object_options.use_adaptive_moment: if errors: ( @@ -272,29 +261,31 @@ def _update( vdecay=object_options.vdecay, mdecay=object_options.mdecay, ) - psi[0] = psi[0] + dpsi / deno - psi = comm.pool.bcast([psi[0]]) + psi = psi + dpsi / deno if recover_probe: - - probe_update_numerator = comm.Allreduce_reduce_gpu( - probe_update_numerator)[0] b0 = tike.ptycho.probe.finite_probe_support( - probe[0], + probe, p=probe_options.probe_support, radius=probe_options.probe_support_radius, degree=probe_options.probe_support_degree, ) - b1 = 
probe_options.additional_probe_penalty * cp.linspace( - 0, 1, probe[0].shape[-3], dtype='float32')[..., None, None] - dprobe = (probe_update_numerator - (b1 + b0) * probe[0]) + b1 = ( + probe_options.additional_probe_penalty + * cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None] + ) + dprobe = probe_update_numerator - (b1 + b0) * probe deno = ( - (1 - algorithm_options.alpha) * probe_options.preconditioner[0] + - algorithm_options.alpha * probe_options.preconditioner[0].max( + (1 - algorithm_options.alpha) * probe_options.preconditioner + + algorithm_options.alpha + * probe_options.preconditioner.max( axis=(-2, -1), keepdims=True, - ) + b0 + b1) - probe[0] = probe[0] + dprobe / deno + ) + + b0 + + b1 + ) + probe = probe + dprobe / deno if probe_options.use_adaptive_moment: # ptychoshelves only applies momentum to the main probe mode = 0 @@ -323,8 +314,7 @@ def _update( vdecay=probe_options.vdecay, mdecay=probe_options.mdecay, ) - probe[0] = probe[0] + dprobe / deno - probe = comm.pool.bcast([probe[0]]) + probe = probe + dprobe / deno return psi, probe @@ -341,7 +331,7 @@ def _get_nearplane_gradients( position_update_denominator: typing.Union[None, npt.NDArray], eigen_probe: typing.Union[None, npt.NDArray], eigen_weights: typing.Union[None, npt.NDArray], - batches: typing.List[typing.List[int]], + batches: typing.List[npt.NDArray[np.intc]], streams: typing.List[cp.cuda.Stream], *, n: int, @@ -351,10 +341,11 @@ def _get_nearplane_gradients( recover_probe: bool, position_options: typing.Union[None, PositionOptions], exitwave_options: ExitWaveOptions, -) -> typing.List[npt.NDArray]: - - cost = 0.0 - count = 1.0 / len(batches[n]) +) -> typing.Tuple[ + float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray | None +]: + cost: float = 0.0 + count: float = 1.0 / len(batches[n]) psi_update_numerator = cp.zeros_like( psi) if psi_update_numerator is None else psi_update_numerator probe_update_numerator = cp.zeros_like( @@ -544,7 +535,7 @@ def 
keep_some_args_constant( ) return ( - cost, + float(cost), psi_update_numerator, probe_update_numerator, position_update_numerator, @@ -559,10 +550,10 @@ def _update_position( position_update_numerator: npt.NDArray, position_update_denominator: npt.NDArray, *, - alpha=0.05, - max_shift=1, - epoch=0, -): + alpha: float = 0.05, + max_shift: float = 1.0, + epoch: int = 0, +) -> typing.Tuple[cp.ndarray, PositionOptions]: if epoch < position_options.update_start: return scan, position_options diff --git a/tests/communicators/test_pool.py b/tests/communicators/test_pool.py index 9f4234fe..ed3ac4b8 100644 --- a/tests/communicators/test_pool.py +++ b/tests/communicators/test_pool.py @@ -140,6 +140,35 @@ def test_reduce_mean(self): # print(result.shape, type(truth)) self.xp.testing.assert_array_equal(result, truth) + def test_swap_edges(self): + + def init(i): + return self.xp.ones((1, 4 * self.pool.num_workers, 1), dtype=int) * i + + x = self.pool.map(init, list(range(self.pool.num_workers))) + + edges = np.arange(self.pool.num_workers, dtype=int) * 4 + overlap = 3 + + x1 = self.pool.swap_edges( + x, + overlap=overlap, + edges=edges, + ) + + print() + for i, element in enumerate(x1): + print(element.flatten()) + truth = self.xp.ones((1, 4 * self.pool.num_workers, 1), dtype=int) * i + if i > 0: + truth[..., edges[i] : edges[i] + overlap, :] = i - 1 + if i < len(x1) - 1: + truth[..., edges[i + 1] : (edges[i + 1] + overlap), :] = i + 1 + self.xp.testing.assert_array_equal( + element, + truth, + ) + class TestSoloThreadPool(TestThreadPool): diff --git a/tests/ptycho/test_position.py b/tests/ptycho/test_position.py index 12482e54..b495ec78 100644 --- a/tests/ptycho/test_position.py +++ b/tests/ptycho/test_position.py @@ -27,6 +27,7 @@ def test_position_join(N=245, num_batch=11): assert np.amax(indices) == N - 1 np.random.shuffle(indices) batches = np.array_split(indices, num_batch) + reorder = np.argsort(np.concatenate(batches)) opts = tike.ptycho.PositionOptions( scan, @@ 
-35,19 +36,17 @@ def test_position_join(N=245, num_batch=11): optsb = [opts.split(b) for b in batches] - # Copies non-array params into new object - new_opts = optsb[0].split([]) + joined = PositionOptions.join(optsb, reorder=reorder) - for b, i in zip(optsb, batches): - new_opts = new_opts.join(b, i) + assert joined is not None np.testing.assert_array_equal( - new_opts.initial_scan, + joined.initial_scan, opts.initial_scan, ) np.testing.assert_array_equal( - new_opts._momentum, + joined._momentum, opts._momentum, ) diff --git a/tests/ptycho/test_probe.py b/tests/ptycho/test_probe.py index 5c1fa7cc..5e40b3c3 100644 --- a/tests/ptycho/test_probe.py +++ b/tests/ptycho/test_probe.py @@ -27,29 +27,21 @@ def test_eigen_probe(self): high = 21 posi = 53 eigen = 1 - comm = Comm(2) - R = comm.pool.bcast([np.random.rand(*leading, posi, 1, 1, wide, high)]) - eigen_probe = comm.pool.bcast( - [np.random.rand(*leading, 1, eigen, 1, wide, high)]) + R = np.random.rand(*leading, posi, 1, 1, wide, high) + eigen_probe = np.random.rand(*leading, 1, eigen, 1, wide, high) weights = np.random.rand(*leading, posi, eigen + 1, 1) weights -= np.mean(weights, axis=-3, keepdims=True) - weights = comm.pool.bcast([weights]) - patches = comm.pool.bcast( - [np.random.rand(*leading, posi, 1, 1, wide, high)]) - diff = comm.pool.bcast( - [np.random.rand(*leading, posi, 1, 1, wide, high)]) + patches = np.random.rand(*leading, posi, 1, 1, wide, high) + diff = np.random.rand(*leading, posi, 1, 1, wide, high) new_probe, new_weights = tike.ptycho.probe.update_eigen_probe( - comm=comm, R=R, eigen_probe=eigen_probe, weights=weights, patches=patches, diff=diff, - batches=[[ - list(range(53)), - ]], + batches=[list(range(53))], batch_index=0, c=1, m=0,