# -*- coding: utf-8 -*-
import torch.nn as nn
import torch
import torch.nn.functional as F
import timm
from mednext.MedNextV1 import MedNeXt


class ConvNext(nn.Module):
    """
    2-D encoder/decoder segmentation network built on a timm ConvNeXt encoder.

    The encoder is ``convnext_atto`` created by timm with ``pretrained=True``,
    ``features_only=True`` and a single input channel. The decoder walks the
    encoder feature maps from deepest to shallowest. Before each decoder block
    it concatenates the running tensor with the matching feature map (a skip
    connection). One extra decoder block without a skip connection follows.
    A 1x1 convolution then produces a single output channel.

    5-D inputs are processed slice by slice along dim 2. Each slice goes
    through the 2-D network and the per-slice outputs are stacked back along
    dim 2.

    Args:
        num_decoder_blocks (int): Number of skip-connected decoder blocks.
            Must equal the number of feature maps returned by the encoder;
            any other value cannot satisfy the channel arithmetic below.

    Attributes:
        encoder (nn.Module): timm feature-extraction model (``convnext_atto``).
        decoder (nn.ModuleList): ``num_decoder_blocks + 1`` decoder blocks.
        final (nn.Sequential): 1x1 convolution segmentation head.

    Raises:
        ValueError: If ``num_decoder_blocks`` does not match the number of
            encoder feature levels.
    """

    def __init__(self, num_decoder_blocks: int = 4) -> None:
        super().__init__()
        self._num_decoder_blocks = num_decoder_blocks
        # Not read anywhere in this file; kept for any external caller.
        self.dim = 3
        self.encoder = timm.create_model(
            "convnext_atto", pretrained=True, features_only=True, in_chans=1
        )

        # Channel count of each encoder feature map, shallowest first.
        encoder_channels = self.encoder.feature_info.channels()
        if num_decoder_blocks != len(encoder_channels):
            raise ValueError(
                f"num_decoder_blocks ({num_decoder_blocks}) must equal the number "
                f"of encoder feature levels ({len(encoder_channels)})"
            )

        self.decoder = nn.ModuleList()
        # _forward concatenates the deepest feature map with itself before the
        # first block, so the first block sees twice the deepest channel count
        # (640 for convnext_atto).
        in_channels = 2 * encoder_channels[-1]
        out_channels = in_channels // 4
        for _ in range(self._num_decoder_blocks):
            self.decoder.append(self.decoder_block(in_channels, out_channels))
            in_channels = out_channels * 2  # for skip connections
            out_channels = out_channels // 2
        # Trailing block with no skip connection: it consumes the previous
        # block's output directly (out_channels * 2 channels).
        self.decoder.append(self.decoder_block(out_channels * 2, out_channels))
        self.final = self.segmentation_head(out_channels, 1)

    @staticmethod
    def decoder_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """
        Build a decoder block: 3x3 conv -> ReLU -> 2x bilinear upsample.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.

        Returns:
            nn.Sequential: The decoder block. It doubles the spatial size.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=False),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
        )

    @staticmethod
    def segmentation_head(in_channels: int, out_channels: int) -> nn.Sequential:
        """
        Build the segmentation head: a single 1x1 convolution.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.

        Returns:
            nn.Sequential: The segmentation head.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

    def _forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """
        Run the 2-D encoder/decoder on one batch of images.

        Args:
            x_in (torch.Tensor): Input accepted by the timm encoder
                (presumably (B, 1, H, W) given ``in_chans=1``; confirm).

        Returns:
            torch.Tensor: Output of the segmentation head.
        """
        # Deepest feature map first, so decoder index i pairs with level i.
        feature_maps = self.encoder(x_in)[::-1]
        out = feature_maps[0]
        for idx, layer in enumerate(self.decoder):
            # Only the first len(feature_maps) blocks have a skip connection;
            # at idx 0 the deepest map is concatenated with itself.
            if idx < len(feature_maps):
                out = torch.cat([out, feature_maps[idx]], dim=1)
            out = layer(out)
        return self.final(out)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """
        Forward pass that handles both 2-D and slice-wise 3-D inputs.

        Args:
            x_in (torch.Tensor): Input tensor. If it is 5-D, each index along
                dim 2 is sliced out, run through the 2-D network, and the
                results are stacked back along dim 2. Any other rank is passed
                straight to the 2-D network.

        Returns:
            torch.Tensor: Output tensor.

        Raises:
            ValueError: If a 5-D input has size 0 along dim 2.
        """
        if x_in.ndim != 5:
            return self._forward(x_in)

        depth = x_in.size(2)
        if depth == 0:
            raise ValueError("5-D input must have at least one slice along dim 2")
        # Collect the slices and stack once. Concatenating inside the loop
        # copies the growing output on every step (quadratic in depth).
        slices = [self._forward(x_in[:, :, i, ...]) for i in range(depth)]
        return torch.stack(slices, dim=2)


class UNet3D(nn.Module):
    """
    3-D U-Net with three pooling stages, a bottleneck, and three up stages.

    Each up stage uses a stride-2 transposed convolution, concatenates the
    matching encoder output, and applies a double conv block.

    Because of the three 2x max-poolings and three stride-2 transposed
    convolutions, every spatial dimension of the input must be divisible
    by 8. Otherwise the skip concatenations fail on mismatched shapes.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels of the final 1x1x1 conv.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()

        # Encoder (contracting path)
        self.encoder_conv1 = self.conv_block(in_channels, 64)
        self.encoder_pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder_conv2 = self.conv_block(64, 128)
        self.encoder_pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder_conv3 = self.conv_block(128, 256)
        self.encoder_pool3 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck_conv = self.conv_block(256, 512)

        # Decoder (expansive path); each conv block input is the upsampled
        # tensor concatenated with the skip connection, hence doubled channels.
        self.decoder_upconv3 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder_conv3 = self.conv_block(512, 256)
        self.decoder_upconv2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder_conv2 = self.conv_block(256, 128)
        self.decoder_upconv1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder_conv1 = self.conv_block(128, 64)

        # Output layer
        self.output_conv = nn.Conv3d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels: int, out_channels: int):
        """Build two (3x3x3 conv -> BatchNorm3d -> ReLU) layers; spatial size is kept."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the U-Net.

        Args:
            x (torch.Tensor): 5-D input (N, C, D, H, W) with D, H, W divisible by 8.

        Returns:
            torch.Tensor: Output with ``out_channels`` channels and the input's
            spatial size.
        """
        # Encoder
        enc1 = self.encoder_conv1(x)
        enc_pool1 = self.encoder_pool1(enc1)
        enc2 = self.encoder_conv2(enc_pool1)
        enc_pool2 = self.encoder_pool2(enc2)
        enc3 = self.encoder_conv3(enc_pool2)
        enc_pool3 = self.encoder_pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck_conv(enc_pool3)

        # Decoder
        dec_upconv3 = self.decoder_upconv3(bottleneck)
        dec_concat3 = torch.cat([dec_upconv3, enc3], dim=1)
        dec_conv3 = self.decoder_conv3(dec_concat3)
        dec_upconv2 = self.decoder_upconv2(dec_conv3)
        dec_concat2 = torch.cat([dec_upconv2, enc2], dim=1)
        dec_conv2 = self.decoder_conv2(dec_concat2)
        dec_upconv1 = self.decoder_upconv1(dec_conv2)
        dec_concat1 = torch.cat([dec_upconv1, enc1], dim=1)
        dec_conv1 = self.decoder_conv1(dec_concat1)

        # Output layer
        output = self.output_conv(dec_conv1)
        return output


class Mednext(nn.Module):
    """
    Thin wrapper around ``MedNeXt`` with a fixed configuration.

    Only the input and output channel counts are configurable. The rest is
    fixed: ``n_channels=32``, ``exp_r=2``, ``kernel_size=3``, residual blocks
    and residual up/down enabled, no deep supervision, and 2 blocks at each
    of the 9 stages. See the MedNeXt implementation for what each option means.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Passed to MedNeXt as ``n_classes``.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()
        self.model = MedNeXt(
            in_channels = in_channels,
            n_channels = 32,
            n_classes = out_channels,
            exp_r=2,
            kernel_size=3,
            deep_supervision=False,
            do_res=True,
            do_res_up_down = True,
            block_counts = [2,2,2,2,2,2,2,2,2]
        )

    def forward(self, x):
        """Delegate directly to the wrapped MedNeXt model."""
        return self.model(x)