Skip to content

Commit

Permalink
Adds 3D vSHARP model (#273)
Browse files Browse the repository at this point in the history
* Adding 3D UNet & config
* Add 3d vsharp & config
* Minor fixes in typing and software package versions
  • Loading branch information
georgeyiasemis authored Apr 2, 2024
1 parent b660012 commit 4355037
Show file tree
Hide file tree
Showing 21 changed files with 1,122 additions and 49 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ BROWSER := python -c "$$BROWSER_PYSCRIPT"
help:
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

clean: clean-build clean-pyc clean-cpy clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts
clean: clean-build clean-pyc clean-cpy clean-ipynb clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts

clean-build: ## remove build artifacts
rm -fr build/
Expand All @@ -46,6 +46,9 @@ clean-cpy: ## remove cython file artifacts
find . -name '*.cpp' -exec rm -f {} +
find . -name '*.so' -exec rm -f {} +

clean-ipynb: ## remove ipynb artifacts
find . -name '.ipynb_checkpoints' -exec rm -rf {} +

clean-test: ## remove test and coverage artifacts
rm -fr .tox/
rm -f .coverage
Expand Down
14 changes: 12 additions & 2 deletions direct/nn/unet/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass

from direct.config.defaults import ModelConfig
from direct.nn.types import InitType


@dataclass
Expand Down Expand Up @@ -30,4 +31,13 @@ class Unet2dConfig(ModelConfig):
dropout_probability: float = 0.0
skip_connection: bool = False
normalized: bool = False
image_initialization: str = "zero_filled"
image_initialization: InitType = InitType.ZERO_FILLED


@dataclass
class UnetModel3dConfig(ModelConfig):
in_channels: int = 2
out_channels: int = 2
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
15 changes: 8 additions & 7 deletions direct/nn/unet/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.nn import functional as F

from direct.data import transforms as T
from direct.nn.types import InitType


class ConvBlock(nn.Module):
Expand Down Expand Up @@ -334,7 +335,7 @@ def __init__(
dropout_probability: float,
skip_connection: bool = False,
normalized: bool = False,
image_initialization: str = "zero_filled",
image_initialization: InitType = InitType.ZERO_FILLED,
**kwargs,
):
"""Inits :class:`Unet2d`.
Expand All @@ -355,8 +356,8 @@ def __init__(
If True, skip connection is used for the output. Default: False.
normalized: bool
If True, Normalized Unet is used. Default: False.
image_initialization: str
Type of image initialization. Default: "zero-filled".
image_initialization: InitType
Type of image initialization. Default: InitType.ZERO_FILLED.
kwargs: dict
"""
super().__init__()
Expand Down Expand Up @@ -437,18 +438,18 @@ def forward(
output: torch.Tensor
Output image of shape (N, height, width, complex=2).
"""
if self.image_initialization == "sense":
if self.image_initialization == InitType.SENSE:
if sensitivity_map is None:
raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.")
raise ValueError("Expected sensitivity_map not to be None with InitType.SENSE image_initialization.")
input_image = self.compute_sense_init(
kspace=masked_kspace,
sensitivity_map=sensitivity_map,
)
elif self.image_initialization == "zero_filled":
elif self.image_initialization == InitType.ZERO_FILLED:
input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
else:
raise ValueError(
f"Unknown image_initialization. Expected `sense` or `zero_filled`. "
f"Unknown image_initialization. Expected InitType.ZERO_FILLED or InitType.SENSE. "
f"Got {self.image_initialization}."
)

Expand Down
Loading

0 comments on commit 4355037

Please sign in to comment.