Added SPADE-LDM code (#436)
* Added SPADE-LDM code:

    Modification of diffusion_model_unet to allow SPADE normalisation to be enabled as an option
    Modification of autoencoder_kl to allow SPADE normalisation to be enabled as an option
    Modification of the inferer and latent inferer to allow a label to be passed through forward when SPADE is active
    Addition of tests: test_spade_diffusion
    Creation of a 2D tutorial using an OASIS subset of images.

Even though I implemented tests, we should check very thoroughly that this works before merging, especially since SPADE normalisation requires labels to be passed to the forward method, and ANY call to forward without a label while SPADE is on will result in an error. In the same fashion, we should ensure that ANY call to forward when SPADE is not on is not disrupted (the code doesn't error out because a label is missing).
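
Below is a minimal sketch of the intended forward contract, for reviewers to sanity-check against the implementation. The constructor arguments and the seg/label_nc names are assumptions based on typical SPADE set-ups, not taken from this diff:

    import torch

    from generative.networks.nets import DiffusionModelUNet, SPADEDiffusionModelUNet

    x = torch.randn(1, 1, 64, 64)            # batch of noisy images
    t = torch.randint(0, 1000, (1,)).long()  # diffusion timesteps

    # Without SPADE, forward must keep working with no label argument at all.
    unet = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        num_channels=(8, 16),
        attention_levels=(False, True),
        num_res_blocks=(1, 1),
        norm_num_groups=8,
    )
    _ = unet(x, timesteps=t)  # no label passed, must not raise

    # With SPADE, every forward call must receive a semantic map; omitting it
    # should raise an error.
    spade_unet = SPADEDiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        label_nc=3,                          # assumed name for the number of label channels
        num_channels=(8, 16),
        attention_levels=(False, True),
        num_res_blocks=(1, 1),
        norm_num_groups=8,
    )
    seg = torch.randn(1, 3, 64, 64)          # semantic map with label_nc channels
    _ = spade_unet(x, timesteps=t, seg=seg)  # assumed keyword for the label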

* Fetch tutorial from other PR

* Made sure norm_params for SPADE had a single affine argument.

* Code formatting.

---------

Co-authored-by: virginiafdez <[email protected]>
Co-authored-by: Mark Graham <[email protected]>
3 people authored Nov 16, 2023
1 parent 3da2673 commit e0e2559
Showing 12 changed files with 4,882 additions and 171 deletions.
190 changes: 143 additions & 47 deletions generative/inferers/inferer.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions generative/networks/nets/__init__.py
@@ -15,5 +15,8 @@
from .controlnet import ControlNet
from .diffusion_model_unet import DiffusionModelUNet
from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .spade_autoencoderkl import SPADEAutoencoderKL
from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
from .spade_network import SPADENet
from .transformer import DecoderOnlyTransformer
from .vqvae import VQVAE
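
With these exports in place, the new SPADE networks can be imported next to the existing ones, for example:

    # The three classes added above, imported from the package namespace.
    from generative.networks.nets import (
        SPADEAutoencoderKL,
        SPADEDiffusionModelUNet,
        SPADENet,
    )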
38 changes: 18 additions & 20 deletions generative/networks/nets/diffusion_model_unet.py
@@ -931,7 +931,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -964,7 +964,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
)

@@ -1103,7 +1103,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.attention = None
@@ -1127,7 +1127,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
self.resnet_2 = ResnetBlock(
spatial_dims=spatial_dims,
@@ -1271,7 +1271,7 @@ def __init__(
add_upsample: bool = True,
resblock_updown: bool = False,
num_head_channels: int = 1,
use_flash_attention: bool = False
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1388,7 +1388,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1422,7 +1422,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
dropout=dropout_cattn,
)
)

@@ -1486,7 +1486,7 @@ def get_down_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_attn:
return AttnDownBlock(
@@ -1518,7 +1518,7 @@ def get_down_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return DownBlock(
@@ -1546,7 +1546,7 @@ def get_mid_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_conditioning:
return CrossAttnMidBlock(
@@ -1560,7 +1560,7 @@ def get_mid_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return AttnMidBlock(
@@ -1592,7 +1592,7 @@ def get_up_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> nn.Module:
if with_attn:
return AttnUpBlock(
@@ -1626,7 +1626,7 @@ def get_up_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)
else:
return UpBlock(
@@ -1688,7 +1688,7 @@ def __init__(
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
dropout_cattn: float = 0.0,
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
@@ -1701,9 +1701,7 @@ def __init__(
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
)
if dropout_cattn > 1.0 or dropout_cattn < 0.0:
raise ValueError(
"Dropout cannot be negative or >1.0!"
)
raise ValueError("Dropout cannot be negative or >1.0!")

# All number of channels should be multiple of num_groups
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1793,7 +1791,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

self.down_blocks.append(down_block)
@@ -1811,7 +1809,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

# up
@@ -1846,7 +1844,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
dropout_cattn=dropout_cattn,
)

self.up_blocks.append(up_block)