Change num_res_blocks to Sequence[int] | int (#238)
* Change num_res_blocks to Sequence[int] | int

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Change num_res_blocks to Sequence[int] | int

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Use ensure_tuple_rep

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update generative/networks/nets/diffusion_model_unet.py

Co-authored-by: Eric Kerfoot <[email protected]>

* Update generative/networks/nets/autoencoderkl.py

Co-authored-by: Eric Kerfoot <[email protected]>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Warvito and ericspod authored Feb 10, 2023
1 parent 8a329b9 commit 803ed12
Showing 4 changed files with 76 additions and 9 deletions.
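For orientation: both networks now normalize num_res_blocks with ensure_tuple_rep from monai.utils, which broadcasts a plain int to a tuple of the requested length and returns a matching-length sequence unchanged. A minimal sketch of that helper's behavior (illustration only, not part of the diff below):

    from monai.utils import ensure_tuple_rep

    num_channels = (32, 64, 64, 64)

    # A plain int is broadcast to one value per level...
    ensure_tuple_rep(2, len(num_channels))             # -> (2, 2, 2, 2)

    # ...and a sequence of matching length passes through unchanged.
    ensure_tuple_rep((1, 2, 2, 4), len(num_channels))  # -> (1, 2, 2, 4)

    # A sequence of the wrong length raises ValueError.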
18 changes: 13 additions & 5 deletions generative/networks/nets/autoencoderkl.py
@@ -19,6 +19,7 @@
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.blocks import Convolution
+ from monai.utils import ensure_tuple_rep

# To install xformers, use pip install xformers==0.0.16rc401
if importlib.util.find_spec("xformers") is not None:
@@ -313,7 +314,7 @@ def __init__(
in_channels: int,
num_channels: Sequence[int],
out_channels: int,
- num_res_blocks: int,
+ num_res_blocks: Sequence[int],
norm_num_groups: int,
norm_eps: float,
attention_levels: Sequence[bool],
@@ -350,7 +351,7 @@ def __init__(
output_channel = num_channels[i]
is_final_block = i == len(num_channels) - 1

- for _ in range(self.num_res_blocks):
+ for _ in range(self.num_res_blocks[i]):
blocks.append(
ResBlock(
spatial_dims=spatial_dims,
@@ -449,7 +450,7 @@ def __init__(
num_channels: Sequence[int],
in_channels: int,
out_channels: int,
- num_res_blocks: int,
+ num_res_blocks: Sequence[int],
norm_num_groups: int,
norm_eps: float,
attention_levels: Sequence[bool],
@@ -511,13 +512,14 @@ def __init__(
)

reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
block_out_ch = reversed_block_out_channels[0]
for i in range(len(reversed_block_out_channels)):
block_in_ch = block_out_ch
block_out_ch = reversed_block_out_channels[i]
is_final_block = i == len(num_channels) - 1

- for _ in range(self.num_res_blocks):
+ for _ in range(reversed_num_res_blocks[i]):
blocks.append(
ResBlock(
spatial_dims=spatial_dims,
@@ -588,7 +590,7 @@ def __init__(
spatial_dims: int,
in_channels: int = 1,
out_channels: int = 1,
- num_res_blocks: int = 2,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
num_channels: Sequence[int] = (32, 64, 64, 64),
attention_levels: Sequence[bool] = (False, False, True, True),
latent_channels: int = 3,
@@ -606,6 +608,12 @@ def __init__(
if len(num_channels) != len(attention_levels):
raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")

+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

+ if len(num_res_blocks) != len(num_channels):
+ raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.")

self.encoder = Encoder(
spatial_dims=spatial_dims,
in_channels=in_channels,
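A short usage sketch of the new parameter on AutoencoderKL (illustration only; the import path is assumed from the repository layout, and the values mirror the test case added below):

    from generative.networks.nets import AutoencoderKL

    # Per-level residual block counts: one entry per entry in num_channels.
    net = AutoencoderKL(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(4, 4, 4),
        latent_channels=4,
        attention_levels=(False, False, False),
        num_res_blocks=(1, 1, 2),
        norm_num_groups=4,
    )

    # num_res_blocks=1 (a plain int) is still accepted and expands to (1, 1, 1);
    # a tuple whose length differs from len(num_channels) raises ValueError.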
16 changes: 12 additions & 4 deletions generative/networks/nets/diffusion_model_unet.py
@@ -39,6 +39,7 @@
import torch.nn.functional as F
from monai.networks.blocks import Convolution, MLPBlock
from monai.networks.layers.factories import Pool
+ from monai.utils import ensure_tuple_rep
from torch import nn

# To install xformers, use pip install xformers==0.0.16rc401
@@ -1610,7 +1611,7 @@ def __init__(
spatial_dims: int,
in_channels: int,
out_channels: int,
- num_res_blocks: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
num_channels: Sequence[int] = (32, 64, 64, 64),
attention_levels: Sequence[bool] = (False, False, True, True),
norm_num_groups: int = 32,
@@ -1642,14 +1643,20 @@ def __init__(
raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")

if isinstance(num_head_channels, int):
- num_head_channels = (num_head_channels,) * len(attention_levels)
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

if len(num_head_channels) != len(attention_levels):
raise ValueError(
"num_head_channels should have the same length as attention_levels. For the i levels without attention,"
" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
)

+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

+ if len(num_res_blocks) != len(num_channels):
+ raise ValueError("`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.")

self.in_channels = in_channels
self.block_out_channels = num_channels
self.out_channels = out_channels
@@ -1693,7 +1700,7 @@ def __init__(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
- num_res_blocks=num_res_blocks,
+ num_res_blocks=num_res_blocks[i],
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
add_downsample=not is_final_block,
@@ -1725,6 +1732,7 @@ def __init__(
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(num_channels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
reversed_attention_levels = list(reversed(attention_levels))
reversed_num_head_channels = list(reversed(num_head_channels))
output_channel = reversed_block_out_channels[0]
@@ -1741,7 +1749,7 @@
prev_output_channel=prev_output_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
- num_res_blocks=num_res_blocks + 1,
+ num_res_blocks=reversed_num_res_blocks[i] + 1,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
add_upsample=not is_final_block,
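Likewise for the UNet, a minimal sketch (illustration only; the import path is assumed, the constructor values mirror the test case added below, and the forward call follows this network's usual (x, timesteps) convention):

    import torch
    from generative.networks.nets import DiffusionModelUNet

    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_res_blocks=(1, 1, 2),  # one count per level in num_channels
        num_channels=(8, 8, 8),
        attention_levels=(False, False, False),
        norm_num_groups=8,
    )

    x = torch.randn(1, 1, 16, 16)
    timesteps = torch.randint(0, 1000, (1,)).long()
    prediction = net(x, timesteps=timesteps)  # output keeps the input's spatial shape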
28 changes: 28 additions & 0 deletions tests/test_autoencoderkl.py
@@ -38,6 +38,21 @@
(1, 1, 16, 16),
(1, 4, 4, 4),
],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": (1, 1, 2),
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
[
{
"spatial_dims": 2,
@@ -161,6 +176,19 @@ def test_model_num_channels_not_same_size_of_attention_levels(self):
norm_num_groups=16,
)

+ def test_model_num_channels_not_same_size_of_num_res_blocks(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKL(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=(8, 8),
+ norm_num_groups=16,
+ )

def test_shape_reconstruction(self):
input_param, input_shape, expected_shape, _ = CASES[0]
net = AutoencoderKL(**input_param).to(device)
23 changes: 23 additions & 0 deletions tests/test_diffusion_model_unet.py
@@ -32,6 +32,17 @@
"norm_num_groups": 8,
}
],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": (1, 1, 2),
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
[
{
"spatial_dims": 2,
@@ -270,6 +281,18 @@ def test_attention_levels_with_different_length_num_head_channels(self):
norm_num_groups=8,
)

+ def test_num_res_blocks_with_different_length_num_channels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=(1, 1),
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )

def test_shape_conditioned_models(self):
net = DiffusionModelUNet(
spatial_dims=2,
