Skip to content

Commit

Permalink
Add non-static and kt sampling (#280)
Browse files Browse the repository at this point in the history
* kt sampling mask functions: `KtGaussian1DMaskFunc`, `KtRadialMaskFunc`, `KtUniformMaskFunc`,
* Non-static sampling (dynamic/multislice) dicitated by the `MaskFuncMode`, which can be STATIC, MULTISLICE, DYNAMIC
* Corresponding tests
  • Loading branch information
georgeyiasemis authored Jul 3, 2024
1 parent d733b81 commit e9e63d6
Show file tree
Hide file tree
Showing 8 changed files with 1,711 additions and 469 deletions.
1,821 changes: 1,464 additions & 357 deletions direct/common/subsample.py

Large diffs are not rendered by default.

17 changes: 10 additions & 7 deletions direct/common/subsample_config.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional

from omegaconf import MISSING

from direct.config.defaults import BaseConfig
from direct.types import MaskFuncMode


@dataclass
class MaskingConfig(BaseConfig):
name: str = MISSING
accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]]
accelerations: tuple[float, ...] = (5.0,)
center_fractions: Optional[tuple[float, ...]] = (0.1,)
uniform_range: bool = False
image_center_crop: bool = False
mode: MaskFuncMode = MaskFuncMode.STATIC

val_accelerations: Tuple[int, ...] = (5, 10)
val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05)
val_accelerations: tuple[float, ...] = (5.0, 10.0)
val_center_fractions: Optional[tuple[float, ...]] = (0.1, 0.05)
5 changes: 1 addition & 4 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
Sample with `sampling_mask` key.
"""
if not self.shape:
shape = sample["kspace"].shape[-3:]
shape = sample["kspace"].shape[1:]
elif any(_ is None for _ in self.shape): # Allow None as values.
kspace_shape = list(sample["kspace"].shape[1:-1])
shape = tuple(_ if _ else kspace_shape[idx] for idx, _ in enumerate(self.shape)) + (2,)
Expand All @@ -328,9 +328,6 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:

sampling_mask = self.mask_func(shape=shape, seed=seed, return_acs=False)

if sample["kspace"].ndim == 5:
sampling_mask = sampling_mask.unsqueeze(0)

if "padding" in sample:
sampling_mask = T.apply_padding(sampling_mask, sample["padding"])

Expand Down
8 changes: 8 additions & 0 deletions direct/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from enum import Enum
from typing import NewType, Union

import numpy as np
import torch
from omegaconf.omegaconf import DictConfig
from torch import nn as nn
Expand All @@ -19,6 +20,7 @@
FileOrUrl = NewType("FileOrUrl", PathOrString)
HasStateDict = Union[nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, GradScaler]
TensorOrNone = Union[None, torch.Tensor]
TensorOrNdarray = Union[torch.Tensor, np.ndarray]


class DirectEnum(str, Enum):
Expand Down Expand Up @@ -59,6 +61,12 @@ class TransformKey(DirectEnum):
SCALING_FACTOR = "scaling_factor"


class MaskFuncMode(DirectEnum):
STATIC = "static"
DYNAMIC = "dynamic"
MULTISLICE = "multislice"


class IntegerListOrTupleStringMeta(type):
"""Metaclass for the :class:`IntegerListOrTupleString` class.
Expand Down
3 changes: 2 additions & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
)
from direct.launch import launch
from direct.train import setup_train
from direct.types import MaskFuncMode


def create_test_transform_cfg(transforms_type):
transforms_config = TransformsConfig(
normalization=NormalizationTransformConfig(scaling_key="masked_kspace"),
masking=MaskingConfig(name="FastMRIRandom"),
masking=MaskingConfig(name="FastMRIRandom", mode=MaskFuncMode.STATIC),
cropping=CropTransformConfig(crop="(32, 32)"),
sensitivity_map_estimation=SensitivityMapEstimationTransformConfig(estimate_sensitivity_maps=True),
transforms_type=transforms_type,
Expand Down
Loading

0 comments on commit e9e63d6

Please sign in to comment.