Skip to content

Commit

Permalink
fix: use shared memory for the smaps.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Dec 6, 2023
1 parent dea99a4 commit 2f2c37f
Showing 1 changed file with 77 additions and 55 deletions.
132 changes: 77 additions & 55 deletions src/simfmri/handlers/acquisition/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

from multiprocessing import shared_memory
from contextlib import contextmanager
import logging
import warnings
from typing import Any, Callable, Mapping, Generator
Expand All @@ -12,7 +13,7 @@
from tqdm.auto import tqdm

from simfmri.simulation import SimData

from numpy.typing import DTypeLike
from ._tools import TrajectoryGeneratorType

# from mrinufft import get_operator
Expand Down Expand Up @@ -198,6 +199,42 @@ def acq_cartesian(
return kdata, kmask


def _init_sm(
name: str, shape: tuple[int, ...], dtype: DTypeLike
) -> shared_memory.SharedMemory:
"""Initialize a shared memory buffer."""
return shared_memory.SharedMemory(
name=name,
create=True,
size=np.prod(shape) * np.dtype(dtype).itemsize,
)


@contextmanager
def shm_manager(
    name: str, shape: tuple[int, ...], dtype: DTypeLike, unlink: bool = False
) -> Generator[np.ndarray | None, None, None]:
    """Context manager attaching to an existing shared memory segment.

    Yields a ``np.ndarray`` view over the segment, or ``None`` if no
    segment with that name exists. On exit the view is dropped and the
    segment is closed (and unlinked when ``unlink=True``).

    Parameters
    ----------
    name
        Name of the shared memory segment to attach to.
    shape
        Shape of the array view over the buffer.
    dtype
        Dtype of the array view.
    unlink
        If True, destroy the segment after closing it.
    """
    try:
        shm = shared_memory.SharedMemory(name=name, create=False)
    except FileNotFoundError:
        # No segment to manage: yield None and skip cleanup entirely.
        # (Calling shm.close() here would raise, since shm is unbound.)
        yield None
        return

    arr = np.ndarray(shape, buffer=shm.buf, dtype=dtype)
    try:
        yield arr
    finally:
        # Drop the array view before closing, otherwise close() raises
        # BufferError because the buffer is still referenced.
        del arr
        shm.close()
        if unlink:
            shm.unlink()


def acq_noncartesian(
sim: SimData,
trajectory_gen: TrajectoryGeneratorType,
Expand All @@ -211,18 +248,20 @@ def acq_noncartesian(
n_samples = np.prod(test_traj.shape[:-1])
dim = test_traj.shape[-1]

kdata_infos = ((n_kspace_frame, sim.n_coils, n_samples), np.complex64)
shm_kdata = shared_memory.SharedMemory(
name="kdata",
create=True,
size=np.prod(kdata_infos[0]) * np.dtype(kdata_infos[1]).itemsize,
)
kmask_infos = ((n_kspace_frame, n_samples, dim), np.float32)
shm_kmask = shared_memory.SharedMemory(
name="kmask",
create=True,
size=np.prod(kmask_infos[0]) * np.dtype(kmask_infos[1]).itemsize,
)
# Allocate kspace data, kspace mask and smaps in shared memory.
shm_infos = {
"kdata": ((n_kspace_frame, sim.n_coils, n_samples), np.complex64),
"kmask": ((n_kspace_frame, n_samples, dim), np.float32),
"smaps": ((sim.n_coils, *sim.shape), np.complex64),
}
_init_sm("kdata", *shm_infos["kdata"])
_init_sm("kmask", *shm_infos["kmask"])
if sim.smaps:
shm_smaps = _init_sm("smaps", *shm_infos["smaps"])
smaps = np.ndarray(
shm_infos["smaps"][0], buffer=shm_smaps.buf, dtype=shm_infos["smaps"][1]
)
smaps[:] = sim.smaps

nufft_backend = kwargs.pop("backend")
logger.debug("Using backend %s", nufft_backend)
Expand All @@ -236,14 +275,7 @@ def acq_noncartesian(
density=False,
backend_name=nufft_backend,
)
if "gpunufft" in nufft_backend:
logger.debug("Using gpunufft, pinning smaps")
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps

op_kwargs["pinned_smaps"] = make_pinned_smaps(sim.smaps)
op_kwargs["smaps"] = None
else:
op_kwargs["smaps"] = sim.smaps
scheduler = kspace_bulk_shot(trajectory_gen, sim.n_frames, n_shot_sim_frame)
with Parallel(n_jobs=n_jobs, backend="loky", mmap_mode="r") as par:
par(
Expand All @@ -252,8 +284,7 @@ def acq_noncartesian(
shot_batch,
shot_pos,
op_kwargs,
kdata_infos,
kmask_infos,
shm_infos,
)
for sim_frame, shot_batch, shot_pos in tqdm(work_generator(sim, scheduler))
)
Expand All @@ -263,19 +294,12 @@ def acq_noncartesian(

get_reusable_executor().shutdown(wait=True)

kdata_ = np.ndarray(kdata_infos[0], buffer=shm_kdata.buf, dtype=kdata_infos[1])
kmask_ = np.ndarray(kmask_infos[0], buffer=shm_kmask.buf, dtype=kmask_infos[1])

kdata = np.copy(kdata_)
kmask = np.copy(kmask_)
del kdata_
del kmask_

shm_kdata.close()
shm_kmask.close()
shm_kdata.unlink()
shm_kmask.unlink()

with (
shm_manager("kdata", *shm_infos["kdata"], unlink=True) as kdata_,
shm_manager("kmask", *shm_infos["kmask"], unlink=True) as kmask_,
):
kdata = np.copy(kdata_)
kmask = np.copy(kmask_)
return kdata, kmask


Expand All @@ -291,32 +315,30 @@ def _single_worker(
shot_batch: np.ndarray,
shot_pos: tuple[int, int],
op_kwargs: Mapping[str, Any],
kdata_infos: tuple[tuple[int], np.Dtype],
kmask_infos: tuple[tuple[int], np.Dtype],
shm_infos: Mapping[str, tuple[tuple[int], np.DtypeLike]],
) -> None:
"""Perform a shot acquisition."""
with warnings.catch_warnings():

with (
warnings.catch_warnings(),
shm_manager("kdata", *shm_infos["kdata"]) as kdata_,
shm_manager("kmask", *shm_infos["kmask"]) as kmask_,
shm_manager("smaps", *shm_infos["smaps"]) as smaps,
):
warnings.filterwarnings(
"ignore",
"Samples will be rescaled to .*",
category=UserWarning,
module="mrinufft",
)
fourier_op = get_operator(samples=shot_batch, **op_kwargs)
kspace = fourier_op.op(sim_frame)
L = shot_batch.shape[1]

shm_kdata = shared_memory.SharedMemory(name="kdata", create=False)
shm_kmask = shared_memory.SharedMemory(name="kmask", create=False)
if "gpunufft" in op_kwargs["backend_name"]:
op_kwargs["pinned_smaps"] = smaps
smaps = None

kdata_ = np.ndarray(kdata_infos[0], buffer=shm_kdata.buf, dtype=kdata_infos[1])
kmask_ = np.ndarray(kmask_infos[0], buffer=shm_kmask.buf, dtype=kmask_infos[1])

for i, (k, s) in enumerate(shot_pos):
kdata_[k, :, s * L : (s + 1) * L] = kspace[..., i * L : (i + 1) * L]
kmask_[k, s * L : (s + 1) * L] = shot_batch[i]
fourier_op = get_operator(samples=shot_batch, smaps=smaps, **op_kwargs)
kspace = fourier_op.op(sim_frame)
L = shot_batch.shape[1]

del kdata_
del kmask_
shm_kdata.close()
shm_kmask.close()
# write to share memory shots location and values.
for i, (k, s) in enumerate(shot_pos):
kdata_[k, :, s * L : (s + 1) * L] = kspace[..., i * L : (i + 1) * L]
kmask_[k, s * L : (s + 1) * L] = shot_batch[i]

0 comments on commit 2f2c37f

Please sign in to comment.