From 860d89373f0c4aee5883360760be6caddf989c42 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 5 Dec 2023 15:30:22 +0100 Subject: [PATCH] feat: inline import for faster discoverability. --- src/simfmri/reconstructors/pysap.py | 64 ++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/src/simfmri/reconstructors/pysap.py b/src/simfmri/reconstructors/pysap.py index 87a861c..ef6b636 100644 --- a/src/simfmri/reconstructors/pysap.py +++ b/src/simfmri/reconstructors/pysap.py @@ -4,19 +4,13 @@ and returns a reconstructed fMRI array. """ from __future__ import annotations -from typing import Literal +from typing import Literal, Protocol import logging import warnings import numpy as np - -from fmri.operators.fourier import FFT_Sense, RepeatOperator, PooledgpuNUFFTSpaceFourier -from fmri.operators.fourier import CartesianSpaceFourier, SpaceFourierBase from modopt.opt.linear import LinearParent from modopt.opt.proximity import ProximityParent -from mrinufft.operators import get_operator -from mrinufft.operators.stacked import traj3d2stacked -from mrinufft.trajectories.density import voronoi from .base import BaseReconstructor from simfmri.simulation import SimData @@ -24,7 +18,32 @@ logger = logging.getLogger(__name__) -def _get_stacked_operator(backend: str, sim: SimData) -> RepeatOperator: +class SpaceFourier(Protocol): + """Fourier operator interface.""" + + n_frames: int + shape: tuple[int] + n_coils: int + uses_sense: bool + + def op(self, x: np.ndarray) -> np.ndarray: + """Apply the Fourier operator.""" + ... + + def adj_op(self, x: np.ndarray) -> np.ndarray: + """Apply the adjoint of the Fourier operator.""" + ... + + +def _get_stacked_operator(backend: str, sim: SimData) -> SpaceFourier: + from mrinufft.operators.stacked import traj3d2stacked + from mrinufft.trajectories.density import voronoi + from mrinufft.operators import get_operator + + from fmri.operators.fourier import ( + RepeatOperator, + ) + nufft_backend = backend.split("-")[1] frame_ops = [] Ns = sim.extra_infos["traj_params"]["n_samples"] @@ -60,10 +79,19 @@ def _get_stacked_operator(backend: str, sim: SimData) -> RepeatOperator: def get_fourier_operator( sim: SimData, cartesian_repeat: bool = False, **kwargs: None -) -> RepeatOperator | CartesianSpaceFourier: +) -> SpaceFourier: """Return a Fourier operator for the given simulation.""" kwargs = kwargs.copy() if kwargs is not None else {} + from fmri.operators.fourier import CartesianSpaceFourier, SpaceFourierBase + from mrinufft.operators import get_operator + + from fmri.operators.fourier import ( + FFT_Sense, + RepeatOperator, + PooledgpuNUFFTSpaceFourier, + ) + density = True backend = sim.extra_infos.get("operator", "fft") logger.info(f"fourier backend is {backend}") @@ -197,9 +225,7 @@ def setup(self, sim: SimData) -> None: self.fourier_op, space_linear_op, space_prox_op, optimizer="pogm" ) - def reconstruct( - self, sim: SimData, fourier_op: SpaceFourierBase | None = None - ) -> np.ndarray: + def reconstruct(self, sim: SimData, fourier_op: None = None) -> np.ndarray: """Reconstruct with Sequential.""" if fourier_op is not None: self.fourier_op = fourier_op @@ -236,7 +262,7 @@ def __init__( time_linear_op: LinearParent = None, time_prox_op: ProximityParent = None, space_prox_op: ProximityParent = None, - fourier_op: SpaceFourierBase = None, + fourier_op: SpaceFourier = None, ): super().__init__() self.lambda_l = lambda_l @@ -265,11 +291,11 @@ def setup(self, sim: SimData) -> None: if self.fourier_op is None: self.fourier_op = get_fourier_operator(sim, cartesian_repeat=False) - logger.debug(f"Space Fourier operator initialized") + logger.debug("Space Fourier operator initialized") if self.time_linear_op is None: self.time_linear_op = TimeFourier(time_axis=0) - logger.debug(f"Time Fourier operator initialized") + logger.debug("Time Fourier operator initialized") if self.lambda_s == "sure": adj_data = self.fourier_op.adj_op(sim.kspace_data) sure_thresh = np.zeros(np.prod(adj_data.shape[1:])) @@ -285,12 +311,12 @@ def setup(self, sim: SimData) -> None: self.time_linear_op, self.lambda_s, thresh_type="soft" ) - logger.debug(f"Prox Time operator initialized") + logger.debug("Prox Time operator initialized") if self.space_prox_op is None: self.space_prox_op = FlattenSVT( self.lambda_l, initial_rank=10, thresh_type="soft-rel" ) - logger.debug(f"Prox Space operator initialized") + logger.debug("Prox Space operator initialized") self.reconstructor = LowRankPlusSparseReconstructor( self.fourier_op, @@ -298,10 +324,10 @@ def setup(self, sim: SimData) -> None: time_prox_op=self.time_prox_op, cost="auto", ) - logger.debug(f"Reconstructor initialized") + logger.debug("Reconstructor initialized") def reconstruct( - self, sim: SimData, fourier_op: SpaceFourierBase | None = None + self, sim: SimData, fourier_op: SpaceFourier | None = None ) -> np.ndarray: """Reconstruct using LowRank+Sparse Method.""" if fourier_op is not None: