Merge pull request #322 from carterbox/multi-gpu-new
REF: Reorganize how multi-device parallelism is implemented
carterbox authored Jul 17, 2024
2 parents 844d902 + a7a40ec commit 73242b0
Showing 16 changed files with 1,260 additions and 959 deletions.
72 changes: 36 additions & 36 deletions src/tike/cluster.py
@@ -17,6 +17,7 @@ def _split_gpu(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a CuPy array on the current device."""
return cp.asarray(x[m], dtype=dtype)


@@ -25,6 +26,7 @@ def _split_host(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a NumPy array."""
return np.asarray(x[m], dtype=dtype)


@@ -33,6 +35,7 @@ def _split_pinned(
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
"""Return x[m] as a CuPy pinned host memory array."""
pinned = cupyx.empty_pinned(shape=(len(m), *x.shape[1:]), dtype=dtype)
pinned[...] = x[m]
return pinned
@@ -171,22 +174,24 @@ def by_scan_stripes(


def by_scan_stripes_contiguous(
*args,
pool: tike.communicators.ThreadPool,
shape: typing.Tuple[int],
dtype: typing.List[npt.DTypeLike],
destination: typing.List[str],
scan: npt.NDArray[np.float32],
fly: int = 1,
batch_method,
batch_method: typing.Literal[
"compact", "wobbly_center", "wobbly_center_random_bootstrap"
],
num_batch: int,
) -> typing.Tuple[typing.List[npt.NDArray],
typing.List[typing.List[npt.NDArray]]]:
"""Split data by into stripes and create contiguously ordered batches.
Divide the field of view into one stripe per devices; within each stripe,
create batches according to the batch_method loading the batches into
contiguous blocks in device memory.
) -> typing.Tuple[
typing.List[npt.NDArray],
typing.List[typing.List[npt.NDArray]],
typing.List[int],
]:
"""Return the indices that will split `scan` into 2D stripes of equal count
and create contiguously ordered batches within those stripes.
Divide the field of view into one stripe per worker in `pool`; within each
stripe, create batches according to the batch_method, loading the batches
into contiguous blocks in device memory.
Parameters
----------
@@ -206,14 +211,14 @@ def by_scan_stripes_contiguous(
Returns
-------
order : List[array[int]]
The locations of the inputs in the original arrays.
For each worker in pool, the indices of the data assigned to that worker.
batches : List[List[array[int]]]
The locations of the elements of each batch
scan : List[array[float32]]
The divided 2D coordinates of the scan positions.
args : List[array[float32]] or None
Each input divided into regions or None if arg was None.
For each worker in pool, for each batch, the indices of the elements in
that batch.
stripe_start : List[int]
The coordinates of the leading edge of each stripe along the 0th
dimension in the scan coordinates, e.g. the minimum coordinate of the
scan positions in each stripe.
"""
if len(shape) != 2:
raise ValueError('The grid shape must have two dimensions.')
@@ -229,6 +234,7 @@
x=scan,
dtype=scan.dtype,
)
stripe_start = [int(np.floor(np.min(x[:, 0]))) for x in split_scan]
batches_noncontiguous: typing.List[typing.List[npt.NDArray]] = pool.map(
getattr(tike.cluster, batch_method),
split_scan,
@@ -247,26 +253,13 @@
batch_breaks,
))

split_args = []
for arg, t, dest in zip([scan, *args], dtype, destination):
if arg is None:
split_args.append(None)
else:
split_args.append(
pool.map(
_split_gpu if dest == 'gpu' else _split_pinned,
map_to_gpu_contiguous,
x=arg,
dtype=t,
))

if __debug__:
for device in batches_contiguous:
assert len(device) == num_batch, (
f"There should be {num_batch} batches, found {len(device)}"
)

return (map_to_gpu_contiguous, batches_contiguous, *split_args)
return (map_to_gpu_contiguous, batches_contiguous, stripe_start)
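A toy sketch (values and variable names are hypothetical) of one way the three returned values might be consumed, under the assumption that `order` indexes into the original arrays and each batch indexes into that worker's reordered block:

```python
import numpy as np

# Toy return values for two workers, two batches each, five scan positions.
order = [np.array([2, 0, 1]), np.array([4, 3])]        # indices into the original arrays
batches = [[np.array([0, 1]), np.array([2])],          # worker 0: indices into its local block
           [np.array([0]), np.array([1])]]             # worker 1
stripe_start = [0, 7]                                   # leading edge of each stripe

scan = np.random.rand(5, 2).astype(np.float32)
for worker_order, worker_batches in zip(order, batches):
    local_scan = scan[worker_order]                     # contiguous block for this worker
    for batch in worker_batches:
        positions = local_scan[batch]                   # scan positions in one batch
```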


def stripes_equal_count(
@@ -306,7 +299,10 @@ def stripes_equal_count(
)


def wobbly_center(population, num_cluster):
def wobbly_center(
population: npt.ArrayLike,
num_cluster: int,
) -> typing.List[npt.NDArray]:
"""Return the indices that divide population into heterogenous clusters.
Uses a contrarian approach to clustering by maximizing the heterogeneity
@@ -382,7 +378,7 @@ def wobbly_center(population, num_cluster):


def wobbly_center_random_bootstrap(
population,
population: npt.ArrayLike,
num_cluster: int,
boot_fraction: float = 0.95,
) -> typing.List[npt.NDArray]:
@@ -466,7 +462,11 @@ def wobbly_center_random_bootstrap(
return [cp.asnumpy(xp.flatnonzero(labels == c)) for c in range(num_cluster)]


def compact(population, num_cluster, max_iter=500):
def compact(
population: npt.ArrayLike,
num_cluster: int,
max_iter: int = 500,
) -> typing.List[npt.NDArray]:
"""Return the indices that divide population into compact clusters.
Uses an approach that is inspired by the naive k-means algorithm, but it
24 changes: 19 additions & 5 deletions src/tike/communicators/comm.py
@@ -13,10 +13,6 @@
from .pool import ThreadPool


def _init_streams():
return [cp.cuda.Stream() for _ in range(2)]


class Comm:
"""A Ptychography communicator.
@@ -47,7 +43,6 @@ def __init__(
self.use_mpi = True
self.mpi = mpi()
self.pool = pool(gpu_count)
self.streams = self.pool.map(_init_streams)

def __enter__(self):
self.mpi.__enter__()
@@ -139,3 +134,22 @@ def Allreduce(
buf.append(
self.mpi.Allreduce(src[self.pool.workers.index(worker)]))
return buf

def swap_edges(
self,
x: typing.List[cp.ndarray],
overlap: int,
edges: typing.List[int],
) -> typing.List[cp.ndarray]:
"""Swap the region of each x with its neighbor around the given edges.
Given iterable x, a list of ND arrays; edges, the coordinates in x
along dimension -2; and overlap, the width of the region to swap around
the edge; trade [..., edge:(edge + overlap), :] between neighbors.
"""
# FIXME: Swap edges between MPI nodes
return self.pool.swap_edges(
x=x,
overlap=overlap,
edges=edges,
)
76 changes: 73 additions & 3 deletions src/tike/communicators/pool.py
@@ -89,6 +89,12 @@ def __init__(
self.num_workers) if self.num_workers > 1 else NoPoolExecutor(
self.num_workers)

def f(worker):
with self.Device(worker):
return [cp.cuda.Stream() for _ in range(2)]

self.streams = list(self.executor.map(f, self.workers))

def __enter__(self):
if self.workers[0] != cp.cuda.Device().id:
raise ValueError(
@@ -397,10 +403,74 @@ def map(
) -> list:
"""ThreadPoolExecutor.map, but wraps call in a cuda.Device context."""

def f(worker, *args):
def f(worker, streams, *args):
with self.Device(worker):
return func(*args, **kwargs)
with streams[1]:
return func(*args, **kwargs)

workers = self.workers if workers is None else workers

return list(self.executor.map(f, workers, *iterables))
return list(self.executor.map(f, workers, self.streams, *iterables))
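A minimal standalone sketch of the pattern used by `map` above: each worker thread switches to its own device and runs the work inside a per-worker CUDA stream. The names and the single-GPU setup are illustrative, not the tike API:

```python
from concurrent.futures import ThreadPoolExecutor

import cupy as cp

workers = [0]  # one CUDA device id per worker; a single GPU in this sketch

# Two streams per worker, created while that worker's device is current.
streams = []
for w in workers:
    with cp.cuda.Device(w):
        streams.append([cp.cuda.Stream() for _ in range(2)])

def run_on_worker(worker, worker_streams, data):
    # Switch to the worker's device, then launch work on its compute stream.
    with cp.cuda.Device(worker):
        with worker_streams[1]:
            return float(cp.asarray(data).sum())

with ThreadPoolExecutor(max_workers=len(workers)) as executor:
    results = list(executor.map(run_on_worker, workers, streams, [[1, 2, 3]]))
print(results)  # [6.0]
```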

def swap_edges(
self,
x: typing.List[cp.ndarray],
overlap: int,
edges: typing.List[int],
):
"""Swap [..., edge:(edge + overlap), :] between neighbors in-place
For example, given overlap=3 and edges=[0, 4, 8, 12], the following
swap would be returned:
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0]]
[[1 1 1 1 0 0 0 1 2 2 2 1 1 1 1 1]]
[[2 2 2 2 2 2 2 2 1 1 1 2 3 3 3 2]]
[[3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 3]]
```
Note that the minimum swapped region is 1 wide.
"""
if overlap < 1:
msg = f"Overlap for swap_edges cannot be less than 1: {overlap}"
raise ValueError(msg)
for i in range(self.num_workers - 1):
lo = edges[i + 1]
hi = lo + overlap
temp0 = self._copy_to(x[i][..., lo:hi, :], self.workers[i + 1])
temp1 = self._copy_to(x[i + 1][..., lo:hi, :], self.workers[i])
with self.Device(self.workers[i]):
rampu = cp.linspace(
0.0,
1.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
rampd = cp.linspace(
1.0,
0.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
x[i][..., lo:hi, :] = rampd * x[i][..., lo:hi, :] + rampu * temp1
with self.Device(self.workers[i + 1]):
rampu = cp.linspace(
0.0,
1.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
rampd = cp.linspace(
1.0,
0.0,
overlap + 2,
endpoint=True,
)[1:-1][..., None]
x[i + 1][..., lo:hi, :] = (
rampd * temp0 + rampu * x[i + 1][..., lo:hi, :]
)
return x
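A host-only NumPy sketch of the linear-ramp blend performed at one stripe boundary in the loop above; the helper name is invented for illustration:

```python
import numpy as np

def blend_edge_sketch(left, right, lo, overlap):
    """Blend the shared [..., lo:lo+overlap, :] region of two neighboring stripes."""
    hi = lo + overlap
    # The ramps drop the exact 0 and 1 endpoints, as in linspace(...)[1:-1] above.
    rampu = np.linspace(0.0, 1.0, overlap + 2)[1:-1][:, None]
    rampd = np.linspace(1.0, 0.0, overlap + 2)[1:-1][:, None]
    temp0 = left[..., lo:hi, :].copy()
    temp1 = right[..., lo:hi, :].copy()
    # Both neighbors end up holding the same ramped combination of the two copies.
    left[..., lo:hi, :] = rampd * temp0 + rampu * temp1
    right[..., lo:hi, :] = rampd * temp0 + rampu * temp1
    return left, right

left = np.zeros((1, 8, 4))
right = np.ones((1, 8, 4))
blend_edge_sketch(left, right, lo=4, overlap=3)
```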
6 changes: 3 additions & 3 deletions src/tike/opt.py
@@ -381,9 +381,9 @@ def conjugate_gradient(


def fit_line_least_squares(
y: typing.List[float],
x: typing.List[float],
) -> typing.Tuple[float, float]:
y: npt.NDArray[np.floating],
x: npt.NDArray[np.floating],
) -> typing.Tuple[np.floating, np.floating]:
"""Return the `slope`, `intercept` pair that best fits `y`, `x` to a line.
y = slope * x + intercept
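A minimal NumPy sketch (not the library's implementation) of the least-squares fit named by the docstring, y = slope * x + intercept:

```python
import numpy as np

def fit_line_sketch(y, x):
    """Return (slope, intercept) minimizing ||slope * x + intercept - y||^2."""
    design = np.stack([x, np.ones_like(x)], axis=-1)  # columns: [x, 1]
    (slope, intercept), *_ = np.linalg.lstsq(design, y, rcond=None)
    return slope, intercept

x = np.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
print(fit_line_sketch(y, x))  # approximately (2.0, 1.0)
```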
36 changes: 24 additions & 12 deletions src/tike/ptycho/exitwave.py
@@ -5,6 +5,7 @@
just free space propagation to the detector.
"""

from __future__ import annotations

import copy
@@ -67,7 +68,7 @@ class ExitWaveOptions:
exitwave updates in Fourier space. `1.0` for no scaling.
"""

propagation_normalization: str = 'ortho'
propagation_normalization: str = "ortho"
"""Choose the scaling of the FFT in the forward model:
"ortho" - the forward and adjoint operations are divided by sqrt(n)
@@ -78,19 +79,29 @@
"""

def copy_to_device(self, comm) -> ExitWaveOptions:
def copy_to_device(self) -> ExitWaveOptions:
"""Copy to the current GPU memory."""
options = copy.copy(self)
if self.measured_pixels is not None:
options.measured_pixels = comm.pool.bcast([self.measured_pixels])
return options
return ExitWaveOptions(
measured_pixels=cp.asarray(self.measured_pixels, dtype=bool),
noise_model=self.noise_model,
propagation_normalization=self.propagation_normalization,
step_length_start=self.step_length_start,
step_length_usemodes=self.step_length_usemodes,
step_length_weight=self.step_length_weight,
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
)

def copy_to_host(self) -> ExitWaveOptions:
"""Copy to the host CPU memory."""
options = copy.copy(self)
if self.measured_pixels is not None:
options.measured_pixels = cp.asnumpy(self.measured_pixels[0])
return options
return ExitWaveOptions(
measured_pixels=cp.asnumpy(self.measured_pixels),
noise_model=self.noise_model,
propagation_normalization=self.propagation_normalization,
step_length_start=self.step_length_start,
step_length_usemodes=self.step_length_usemodes,
step_length_weight=self.step_length_weight,
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
)
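A short usage sketch of the new device/host round trip; it assumes the remaining `ExitWaveOptions` fields have defaults, which is not shown in this diff:

```python
import cupy as cp
import numpy as np

# Hypothetical round trip; assumes the other ExitWaveOptions fields have defaults.
opts = ExitWaveOptions(measured_pixels=np.ones((16, 16), dtype=bool))
gpu_opts = opts.copy_to_device()      # measured_pixels becomes a CuPy bool array
host_opts = gpu_opts.copy_to_host()   # and returns to a NumPy array on the host
```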

def resample(self, factor: float) -> ExitWaveOptions:
"""Return a new `ExitWaveOptions` with the parameters rescaled."""
@@ -103,8 +114,9 @@ def resample(self, factor: float) -> ExitWaveOptions:
self.measured_pixels,
int(self.measured_pixels.shape[-1] * factor),
),
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
propagation_normalization=self.propagation_normalization )
unmeasured_pixels_scaling=self.unmeasured_pixels_scaling,
propagation_normalization=self.propagation_normalization,
)


def poisson_steplength_all_modes(
