Skip to content

Commit

Permalink
Merge pull request #122 from directgroup/implement-unet-baseline
Browse files Browse the repository at this point in the history
Implement unet baseline
  • Loading branch information
georgeyiasemis authored Oct 5, 2021
2 parents 9279eb8 + 597cdf9 commit 36f4fa0
Show file tree
Hide file tree
Showing 3 changed files with 584 additions and 1 deletion.
10 changes: 10 additions & 0 deletions direct/nn/unet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ class UnetModel2dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0


@dataclass
class Unet2dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
skip_connection: bool = False
normalized: bool = False
image_initialization: str = "zero_filled"
102 changes: 101 additions & 1 deletion direct/nn/unet/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

# Code borrowed / edited from: https://github.com/facebookresearch/fastMRI/blob/
import math
from typing import List, Tuple
from typing import Callable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from direct.data import transforms as T


class ConvBlock(nn.Module):
"""
Expand Down Expand Up @@ -312,3 +314,101 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
output = self.unnorm(output, mean, std)

return output


class Unet2d(nn.Module):
"""
PyTorch implementation of a U-Net model for MRI Reconstruction.
"""

def __init__(
self,
forward_operator: Callable,
backward_operator: Callable,
num_filters: int,
num_pool_layers: int,
dropout_probability: float,
skip_connection: bool = False,
normalized: bool = False,
image_initialization: str = "zero_filled",
**kwargs,
):
super().__init__()

extra_keys = kwargs.keys()
for extra_key in extra_keys:
if extra_key not in [
"sensitivity_map_model",
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")

if normalized:
self.unet = NormUnetModel2d(
in_channels=2,
out_channels=2,
num_filters=num_filters,
num_pool_layers=num_pool_layers,
dropout_probability=dropout_probability,
)
else:
self.unet = UnetModel2d(
in_channels=2,
out_channels=2,
num_filters=num_filters,
num_pool_layers=num_pool_layers,
dropout_probability=dropout_probability,
)

self.forward_operator = forward_operator
self.backward_operator = backward_operator

self.skip_connection = skip_connection

self.image_initialization = image_initialization

self._coil_dim = 1
self._spatial_dims = (2, 3)

def compute_sense_init(self, kspace, sensitivity_map, spatial_dims, coil_dim):

input_image = T.complex_multiplication(
T.conjugate(sensitivity_map),
self.backward_operator(kspace, dim=spatial_dims),
) # shape (batch, coil, height, width, complex=2)

input_image = input_image.sum(coil_dim)

# shape (batch, height, width, complex=2)
return input_image

def forward(
self,
masked_kspace: torch.Tensor,
sensitivity_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# kspace and sensitivity_map are of shape: (batch, coil, height, width, complex=2)

if self.image_initialization == "sense":
if sensitivity_map is None:
raise ValueError("Expected sensitivity_map not to be None with 'sense' image_initialization.")
input_image = self.compute_sense_init(
kspace=masked_kspace,
sensitivity_map=sensitivity_map,
spatial_dims=self._spatial_dims,
coil_dim=self._coil_dim,
)
elif self.image_initialization == "zero_filled":
input_image = self.backward_operator(masked_kspace).sum(self._coil_dim)
else:
raise ValueError(
f"Unknown image_initialization. Expected `sense` or `zero_filled`. "
f"Got {self.image_initialization}."
)

output = self.unet(input_image.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

if self.skip_connection:
output += input_image

return output
Loading

0 comments on commit 36f4fa0

Please sign in to comment.