diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index 6a21880e..366c8ee2 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -19,6 +19,7 @@ import torch.nn as nn import torch.nn.functional as F from monai.networks.blocks import Convolution +from monai.utils import ensure_tuple_rep # To install xformers, use pip install xformers==0.0.16rc401 if importlib.util.find_spec("xformers") is not None: @@ -313,7 +314,7 @@ def __init__( in_channels: int, num_channels: Sequence[int], out_channels: int, - num_res_blocks: int, + num_res_blocks: Sequence[int], norm_num_groups: int, norm_eps: float, attention_levels: Sequence[bool], @@ -350,7 +351,7 @@ def __init__( output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 - for _ in range(self.num_res_blocks): + for _ in range(self.num_res_blocks[i]): blocks.append( ResBlock( spatial_dims=spatial_dims, @@ -449,7 +450,7 @@ def __init__( num_channels: Sequence[int], in_channels: int, out_channels: int, - num_res_blocks: int, + num_res_blocks: Sequence[int], norm_num_groups: int, norm_eps: float, attention_levels: Sequence[bool], @@ -511,13 +512,14 @@ def __init__( ) reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) block_out_ch = reversed_block_out_channels[0] for i in range(len(reversed_block_out_channels)): block_in_ch = block_out_ch block_out_ch = reversed_block_out_channels[i] is_final_block = i == len(num_channels) - 1 - for _ in range(self.num_res_blocks): + for _ in range(reversed_num_res_blocks[i]): blocks.append( ResBlock( spatial_dims=spatial_dims, @@ -588,7 +590,7 @@ def __init__( spatial_dims: int, in_channels: int = 1, out_channels: int = 1, - num_res_blocks: int = 2, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), latent_channels: int = 3, @@ -606,6 +608,12 @@ def __init__( if len(num_channels) != len(attention_levels): raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels") + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.") + self.encoder = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 68a5f882..38b532ae 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -39,6 +39,7 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution, MLPBlock from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep from torch import nn # To install xformers, use pip install xformers==0.0.16rc401 @@ -1610,7 +1611,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - num_res_blocks: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, @@ -1642,7 +1643,7 @@ def __init__( raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") if isinstance(num_head_channels, int): - num_head_channels = (num_head_channels,) * len(attention_levels) + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) if len(num_head_channels) != len(attention_levels): raise ValueError( @@ -1650,6 +1651,12 @@ def __init__( " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." ) + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.") + self.in_channels = in_channels self.block_out_channels = num_channels self.out_channels = out_channels @@ -1693,7 +1700,7 @@ def __init__( in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks, + num_res_blocks=num_res_blocks[i], norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=not is_final_block, @@ -1725,6 +1732,7 @@ def __init__( # up self.up_blocks = nn.ModuleList([]) reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) reversed_attention_levels = list(reversed(attention_levels)) reversed_num_head_channels = list(reversed(num_head_channels)) output_channel = reversed_block_out_channels[0] @@ -1741,7 +1749,7 @@ def __init__( prev_output_channel=prev_output_channel, out_channels=output_channel, temb_channels=time_embed_dim, - num_res_blocks=num_res_blocks + 1, + num_res_blocks=reversed_num_res_blocks[i] + 1, norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=not is_final_block, diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 54bc8dba..e6280169 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -38,6 +38,21 @@ (1, 1, 16, 16), (1, 4, 4, 4), ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_channels": (4, 4, 4), + "latent_channels": 4, + "attention_levels": (False, False, False), + "num_res_blocks": (1, 1, 2), + "norm_num_groups": 4, + }, + (1, 1, 16, 16), + (1, 1, 16, 16), + (1, 4, 4, 4), + ], [ { "spatial_dims": 2, @@ -161,6 +176,19 @@ def test_model_num_channels_not_same_size_of_attention_levels(self): norm_num_groups=16, ) + def test_model_num_channels_not_same_size_of_num_res_blocks(self): + with self.assertRaises(ValueError): + AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), + latent_channels=8, + num_res_blocks=(8, 8), + norm_num_groups=16, + ) + def test_shape_reconstruction(self): input_param, input_shape, expected_shape, _ = CASES[0] net = AutoencoderKL(**input_param).to(device) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 1c5c647b..ebda9d31 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -32,6 +32,17 @@ "norm_num_groups": 8, } ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": (1, 1, 2), + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + } + ], [ { "spatial_dims": 2, @@ -270,6 +281,18 @@ def test_attention_levels_with_different_length_num_head_channels(self): norm_num_groups=8, ) + def test_num_res_blocks_with_different_length_num_channels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=(1, 1), + num_channels=(8, 8, 8), + attention_levels=(False, False, False), + norm_num_groups=8, + ) + def test_shape_conditioned_models(self): net = DiffusionModelUNet( spatial_dims=2,