Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 28, 2024
1 parent dff6580 commit 0c5e838
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 43 deletions.
6 changes: 3 additions & 3 deletions examples/example_vds_simrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@
# Sequential Reconstruction
# ~~~~~~~~~~~~~~~~~~~~~~~~~

seq_data = SequentialReconstructor(max_iter_per_frame=20, threshold="sure").reconstruct(
sim
)
seq_data = SequentialReconstructor(
max_iter_per_frame=20, threshold="sure", compute_backend="numpy"
).reconstruct(sim)
fig4 = tile_view(abs(seq_data), samples=0.1, axis=0)

# %%
Expand Down
16 changes: 13 additions & 3 deletions src/conf/scenario1.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# This files contains the configuration to reproduce the scenario 1 of the Snkf paper.

defaults:
- handlers:
- phantom-brainweb
- activation-block
- noise-gaussian
- acquisition-vds
- reconstructors: adjoint
- _self_

force_sim: false
cache_dir: ${oc.env:PWD}/cache
result_dir: results
Expand All @@ -16,9 +25,10 @@ sim_params:
n_coils: 1
rng: 19980408
lazy: True

handlers:
phantom-brainweb:
subject_id: 5
sub_id: 5
bbox: [0.225,-0.07, 0.06, -0.055, null, null]
brainweb_folder: ${cache_dir}/brainweb
res: [3.0, 3.0,2.81 ]
Expand All @@ -29,14 +39,14 @@ handlers:
duration: 300
bold_strength: 0.05
noise-gaussian:
snr: 100
snr: 30
acquisition-vds:
shot_time_ms: 50
acs: 1
accel: 1
accel_axis: 0
constant: true
direction: top-down
order: TOP_DOWN
smaps: false
n_jobs: 5

Expand Down
2 changes: 1 addition & 1 deletion src/conf/scenario2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ handlers:
duration: 300
bold_strength: 0.02
noise-gaussian:
snr: 100
snr: 30
acquisition-sos:
shot_time_ms: 50
acsz: 0.125
Expand Down
4 changes: 2 additions & 2 deletions src/conf/scenario3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ handlers:
duration: 300
bold_strength: 0.02
noise-gaussian:
snr: 100
snr: 30
acquisition-generic-noncartesian:
shot_time_ms: 50
n_jobs: 4
traj_files: ${oc.env:PWD}/cache/trajectory/sparkling3d-48-2688x5.bin
traj_files: ${cache_dir}/trajectory/sparkling3d-48-2688x5.bin
traj_osf: 5
smaps: true
backend: "cufinufft"
Expand Down
2 changes: 1 addition & 1 deletion src/snkf/handlers/acquisition/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def acq_cartesian(
if sim.n_coils == 1:
kdata[kk, 0, mask] = kdata_t[i][ii]
else:
kdata[kk, 0, mask] = kdata_t[i][ii]
kdata[kk, :, mask] = kdata_t[i][ii].T
kmask[kk] += mask
return kdata, kmask

Expand Down
6 changes: 5 additions & 1 deletion src/snkf/reconstructors/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base Interfaces for the reconstructors."""

import logging
from dataclasses import field
import numpy as np
from typing import Protocol, Any, ClassVar
from snkf.simulation import SimData
Expand Down Expand Up @@ -38,7 +39,10 @@ class BaseReconstructor(metaclass=MetaReconstructor):
__registry__: ClassVar[dict]
__reconstructor_name__: ClassVar[str]

nufft_kwargs: dict[str, Any]
nufft_kwargs: dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
pass

def setup(self, sim: SimData) -> None:
"""Set up the reconstructor."""
Expand Down
50 changes: 19 additions & 31 deletions src/snkf/reconstructors/pysap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

from __future__ import annotations
from dataclasses import field
from typing import Literal, Any
from typing import Any
import logging
import numpy as np

from modopt.opt.linear import LinearParent, Identity
from modopt.opt.proximity import ProximityParent
from modopt.opt.linear import Identity

from .base import BaseReconstructor, SpaceFourierProto
from snkf.simulation import SimData
Expand Down Expand Up @@ -55,7 +54,8 @@ def get_fourier_operator(
if backend_name in ["cufinufft", "stacked-cufinufft"]:
import cupy as cp

smaps = cp.array(smaps)
if smaps is not None:
smaps = cp.array(smaps)
kwargs["smaps_cached"] = True
if sim.extra_infos["traj_params"]["constant"] is True and backend_name != "":
logger.debug("using a duplicated operator.")
Expand Down Expand Up @@ -106,10 +106,9 @@ class ZeroFilledReconstructor(BaseReconstructor):

__reconstructor_name__ = "adjoint"

def __init__(self, nufft_kwargs: dict = None):
if nufft_kwargs is None:
nufft_kwargs = {}
self.nufft_kwargs = nufft_kwargs
def __post_init__(self):
super().__post_init__()
self.fourier_op = None

def setup(self, sim: SimData) -> None:
"""Set up the reconstructor."""
Expand Down Expand Up @@ -223,28 +222,18 @@ class LowRankPlusSparseReconstructor(BaseReconstructor):

__reconstructor_name__ = "lr_f"

def __init__(
self,
lambda_l: float = 0.1,
lambda_s: float | Literal["sure"] = 1,
algorithm: str = "otazo",
max_iter: int = 20,
time_linear_op: LinearParent = None,
time_prox_op: ProximityParent = None,
space_prox_op: ProximityParent = None,
fourier_op: SpaceFourierProto | None = None,
nufft_kwargs: dict[str, Any] | None = None,
):
super().__init__(nufft_kwargs)
self.lambda_l = lambda_l
self.lambda_s = lambda_s
self.max_iter = max_iter
self.algorithm = algorithm

self.time_linear_op = time_linear_op
self.time_prox_op = time_prox_op
self.space_prox_op = space_prox_op
self.fourier_op = fourier_op
nufft_kwargs: dict[str, Any] = field(default_factory=dict)
lambda_l: float = 0.1
lambda_s: float | str = 0.1
algorithm: str = "otazo"
max_iter: int = 20

def __post_init__(self):
super().__post_init__()
self.time_linear_op = None
self.time_prox_op = None
self.space_prox_op = None
self.fourier_op = None

def __str__(self):
if isinstance(self.lambda_s, float):
Expand All @@ -262,7 +251,6 @@ def setup(self, sim: SimData) -> None:
if self.fourier_op is None:
self.fourier_op = get_fourier_operator(
sim,
cartesian_repeat=False,
**self.nufft_kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

def test_scenario1():
cmd = ["src/snkf/cli/main.py", "--config-name=scenario1.yaml"]
result, _err = run_python_script(cmd)
result, _err = run_python_script(cmd, allow_warnings=True)

0 comments on commit 0c5e838

Please sign in to comment.