Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds 3D vSHARP model #273

Merged
merged 10 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ BROWSER := python -c "$$BROWSER_PYSCRIPT"
help:
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)

clean: clean-build clean-pyc clean-cpy clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts
clean: clean-build clean-pyc clean-cpy clean-ipynb clean-test clean-docs ## remove all build, test, coverage, docs and Python and cython artifacts

clean-build: ## remove build artifacts
rm -fr build/
Expand All @@ -46,6 +46,9 @@ clean-cpy: ## remove cython file artifacts
find . -name '*.cpp' -exec rm -f {} +
find . -name '*.so' -exec rm -f {} +

clean-ipynb: ## remove ipynb artifacts
find . -name '.ipynb_checkpoints' -exec rm -rf {} +

clean-test: ## remove test and coverage artifacts
rm -fr .tox/
rm -f .coverage
Expand Down
14 changes: 12 additions & 2 deletions direct/nn/unet/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass

from direct.config.defaults import ModelConfig
from direct.nn.types import InitType


@dataclass
Expand Down Expand Up @@ -30,4 +31,13 @@ class Unet2dConfig(ModelConfig):
dropout_probability: float = 0.0
skip_connection: bool = False
normalized: bool = False
image_initialization: str = "zero_filled"
image_initialization: InitType = InitType.ZERO_FILLED


@dataclass
class UnetModel3dConfig(ModelConfig):
in_channels: int = 2
out_channels: int = 2
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
15 changes: 8 additions & 7 deletions direct/nn/unet/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.nn import functional as F

from direct.data import transforms as T
from direct.nn.types import InitType


class ConvBlock(nn.Module):
Expand Down Expand Up @@ -334,7 +335,7 @@ def __init__(
dropout_probability: float,
skip_connection: bool = False,
normalized: bool = False,
image_initialization: str = "zero_filled",
image_initialization: InitType = InitType.ZERO_FILLED,
**kwargs,
):
"""Inits :class:`Unet2d`.
Expand All @@ -355,8 +356,8 @@ def __init__(
If True, skip connection is used for the output. Default: False.
normalized: bool
If True, Normalized Unet is used. Default: False.
image_initialization: str
Type of image initialization. Default: "zero-filled".
image_initialization: InitType
Type of image initialization. Default: InitType.ZERO_FILLED.
kwargs: dict
"""
super().__init__()
Expand Down Expand Up @@ -437,18 +438,18 @@ def forward(
output: torch.Tensor
Output image of shape (N, height, width, complex=2).
"""
if self.image_initialization == "sense":
if self.image_initialization == InitType.SENSE:
if sensitivity_map is None:
raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.")
raise ValueError("Expected sensitivity_map not to be None with InitType.SENSE image_initialization.")
input_image = self.compute_sense_init(
kspace=masked_kspace,
sensitivity_map=sensitivity_map,
)
elif self.image_initialization == "zero_filled":
elif self.image_initialization == InitType.ZERO_FILLED:
input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
else:
raise ValueError(
f"Unknown image_initialization. Expected `sense` or `zero_filled`. "
f"Unknown image_initialization. Expected InitType.ZERO_FILLED or InitType.SENSE. "
f"Got {self.image_initialization}."
)

Expand Down
Loading
Loading