deduplicate some norm layers, clean up some init logic
neggles committed Jul 13, 2024
1 parent 8652d8c commit b9c9068
Showing 5 changed files with 53 additions and 105 deletions.
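All five files previously carried their own copy of the same 32-group GroupNorm helper (`get_norm_layer` in attention.py and model.py, `GroupNorm32`/`normalization` in diffusion/util.py and layers.py, imported as `normalization` in openaimodel.py); this commit points every call site at a single `Normalize` layer in `neurosis.modules.layers`, and drops the duplicate `zero_module` from attention.py in favour of the one in diffusion/util.py. A minimal sketch of the shared layer (the class body mirrors the one added to layers.py below; the usage is illustrative):

```python
import torch
from torch import nn


# Mirrors the Normalize class added in src/neurosis/modules/layers.py below.
class Normalize(nn.GroupNorm):
    def __init__(self, in_channels: int, num_groups: int = 32):
        super().__init__(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


# Illustrative call site: every former get_norm_layer()/normalization() caller
# now builds the same 32-group GroupNorm from a channel count alone.
x = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
norm = Normalize(64)
print(norm(x).shape)  # torch.Size([2, 64, 16, 16])
```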
51 changes: 15 additions & 36 deletions src/neurosis/modules/attention.py
@@ -1,16 +1,18 @@
 import logging
 import math
 from functools import wraps
-from typing import Any, Iterable, Optional
+from typing import Any, Optional
 
 import torch
 from einops import rearrange, repeat
 from packaging import version
-from torch import FloatTensor, Tensor, nn
+from torch import Tensor, nn
 from torch.backends.cuda import SDPBackend, sdp_kernel
 from torch.nn import functional as F
 from torch.utils.checkpoint import checkpoint
 
+from neurosis.modules.diffusion.util import zero_module
+from neurosis.modules.layers import Normalize
 
 logger = logging.getLogger(__name__)
 
 try:
@@ -45,28 +47,6 @@
 }
 
 
-def uniq(arr: Iterable[Any]) -> list:
-    return list(set(arr))
-
-
-def max_neg_value(t: Tensor) -> FloatTensor:
-    return -torch.finfo(t.dtype).max
-
-
-def zero_module(module: nn.Module) -> nn.Module:
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-@wraps(nn.GroupNorm.__init__)
-def get_norm_layer(in_channels: int) -> nn.GroupNorm:
-    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
@@ -173,7 +153,7 @@ def __init__(self, in_channels: int):
         super().__init__()
         self.in_channels = in_channels
 
-        self.norm = get_norm_layer(in_channels)
+        self.norm = Normalize(in_channels)
         self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
@@ -614,24 +594,23 @@ def __init__(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
 
-        if context_dim is not None and not isinstance(context_dim, (list)):
-            context_dim = [context_dim]
-        if context_dim is not None and isinstance(context_dim, list):
-            if depth != len(context_dim):
+        if context_dim is not None:
+            if not isinstance(context_dim, list):
+                context_dim = [context_dim]
+            if len(context_dim) != depth:
                 logger.debug(
                     f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
-                # depth does not match context dims.
-                assert all(
-                    map(lambda x: x == context_dim[0], context_dim)
-                ), "need homogenous context_dim to match depth automatically"
-                context_dim = depth * [context_dim[0]]
-        elif context_dim is None:
+                if not all((x == context_dim[0] for x in context_dim)):
+                    raise ValueError("need homogenous context_dim to match depth automatically")
+                context_dim = [context_dim[0]] * depth
+        else:
             context_dim = [None] * depth
 
         self.in_channels = in_channels
-        self.norm = get_norm_layer(in_channels)
+        self.norm = Normalize(in_channels)
 
         inner_dim = n_heads * d_head
         if not use_linear:
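The reworked `context_dim` handling in the transformer constructor above folds the old two-branch `isinstance` checks into a single `if context_dim is not None` block and replaces the bare `assert` with a `ValueError`. A standalone sketch of the same normalization (the helper name is hypothetical; in the diff the logic runs inline in `__init__`, and the `logger.debug` call is omitted here):

```python
# Hypothetical standalone version of the context_dim normalization shown above.
def normalize_context_dim(context_dim, depth: int) -> list:
    if context_dim is not None:
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        if len(context_dim) != depth:
            # only a homogeneous list can be broadcast to the requested depth
            if not all(x == context_dim[0] for x in context_dim):
                raise ValueError("need homogenous context_dim to match depth automatically")
            context_dim = [context_dim[0]] * depth
    else:
        context_dim = [None] * depth
    return context_dim


print(normalize_context_dim(2048, depth=3))    # [2048, 2048, 2048]
print(normalize_context_dim([1024], depth=2))  # [1024, 1024]
print(normalize_context_dim(None, depth=2))    # [None, None]
# normalize_context_dim([768, 1024], depth=3) raises ValueError
```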
21 changes: 8 additions & 13 deletions src/neurosis/modules/diffusion/model.py
@@ -1,7 +1,6 @@
 # pytorch_diffusion + derived encoder decoder
 import logging
 import math
-from functools import wraps
 from typing import Any, Optional, Sequence
 from warnings import warn
 
@@ -12,6 +11,7 @@
 from torch.nn import functional as F
 
 from neurosis.modules.attention import LinearAttention, MemoryEfficientCrossAttention
+from neurosis.modules.layers import Normalize
 from neurosis.modules.regularizers import DiagonalGaussianRegularizer
 
 logger = logging.getLogger(__name__)
@@ -48,11 +48,6 @@ def get_timestep_embedding(timesteps: Tensor, embedding_dim: int) -> Tensor:
     return emb
 
 
-@wraps(nn.GroupNorm.__init__)
-def get_norm_layer(in_channels: int) -> nn.GroupNorm:
-    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
 class Upsample(nn.Module):
     def __init__(self, in_channels: int, with_conv: bool):
         super().__init__()
@@ -103,11 +98,11 @@ def __init__(
         self.out_channels = out_channels
         self.use_conv_shortcut = conv_shortcut
 
-        self.norm1 = get_norm_layer(in_channels)
+        self.norm1 = Normalize(in_channels)
         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         if temb_channels > 0:
             self.temb_proj = nn.Linear(temb_channels, out_channels)
-        self.norm2 = get_norm_layer(out_channels)
+        self.norm2 = Normalize(out_channels)
         self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
         if self.in_channels != self.out_channels:
@@ -151,7 +146,7 @@ def __init__(self, in_channels):
         super().__init__()
         self.in_channels = in_channels
 
-        self.norm = get_norm_layer(in_channels)
+        self.norm = Normalize(in_channels)
         self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
@@ -189,7 +184,7 @@ def __init__(self, in_channels: int):
         super().__init__()
         self.in_channels = in_channels
 
-        self.norm = get_norm_layer(in_channels)
+        self.norm = Normalize(in_channels)
         self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
@@ -404,7 +399,7 @@ def __init__(
             self.up.insert(0, up)  # prepend to get consistent order
 
         # end
-        self.norm_out = get_norm_layer(block_in)
+        self.norm_out = Normalize(block_in)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
     def forward(self, x: Tensor, t: Optional[Tensor] = None, context: Optional[Tensor] = None) -> Tensor:
@@ -539,7 +534,7 @@ def __init__(
         )
 
         # end
-        self.norm_out = get_norm_layer(block_in)
+        self.norm_out = Normalize(block_in)
         self.conv_out = nn.Conv2d(
             block_in,
             2 * z_channels if double_z else z_channels,
@@ -698,7 +693,7 @@ def __init__(
             self.up.insert(0, up)  # prepend to get consistent order
 
         # end
-        self.norm_out = get_norm_layer(block_in)
+        self.norm_out = Normalize(block_in)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
         self.max_batch_size = None

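Both attention blocks touched in this file share the same prologue: a `Normalize` over the input followed by 1x1 convolutions producing q, k and v. A small illustrative sketch of that pattern (channel count and variable names are made up; 512 is divisible by the 32 GroupNorm groups):

```python
import torch
from torch import nn

from neurosis.modules.layers import Normalize

in_channels = 512
norm = Normalize(in_channels)
q_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
k_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
v_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

h = norm(torch.randn(1, in_channels, 8, 8))
q, k, v = q_conv(h), k_conv(h), v_conv(h)  # each: (1, 512, 8, 8)
```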
10 changes: 5 additions & 5 deletions src/neurosis/modules/diffusion/openaimodel.py
@@ -13,10 +13,10 @@
 from neurosis.modules.diffusion.util import (
     avg_pool_nd,
     conv_nd,
-    normalization,
     timestep_embedding,
     zero_module,
 )
+from neurosis.modules.layers import Normalize
 
 logger = logging.getLogger(__name__)
 
@@ -246,7 +246,7 @@ def __init__(
         padding = kernel_size // 2
 
         self.in_layers = nn.Sequential(
-            normalization(channels),
+            Normalize(channels),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
         )
@@ -280,7 +280,7 @@ def __init__(
         )
 
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            Normalize(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -368,7 +368,7 @@ def __init__(
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
+        self.norm = Normalize(channels)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
         if use_new_attention_order:
             # split qkv before split heads
@@ -796,7 +796,7 @@ def __init__(
         self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch),
+            Normalize(ch),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )
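In this file the swap is mechanical: every `normalization(...)` call becomes `Normalize(...)`. For orientation, a sketch of the `in_layers`/`out_layers` pattern the ResBlock hunks above touch, with made-up sizes (the zero-initialised final conv is what `zero_module` provides):

```python
import torch
from torch import nn

from neurosis.modules.diffusion.util import conv_nd, zero_module
from neurosis.modules.layers import Normalize

channels, dims, dropout = 320, 2, 0.0  # made-up sizes

in_layers = nn.Sequential(
    Normalize(channels),
    nn.SiLU(),
    conv_nd(dims, channels, channels, 3, padding=1),
)
out_layers = nn.Sequential(
    Normalize(channels),
    nn.SiLU(),
    nn.Dropout(p=dropout),
    # zero-initialised output conv: the residual branch contributes nothing at init
    zero_module(conv_nd(dims, channels, channels, 3, padding=1)),
)

h = out_layers(in_layers(torch.randn(1, channels, 8, 8)))  # (1, 320, 8, 8)
```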
52 changes: 22 additions & 30 deletions src/neurosis/modules/diffusion/util.py
@@ -10,6 +10,7 @@
 """
 
 import math
+from functools import wraps
 from typing import Any, Callable, Optional, Sequence
 
 import numpy as np
@@ -181,7 +182,7 @@ def zero_module(module: nn.Module) -> nn.Module:
     Zero out the parameters of a module and return it.
     """
     for p in module.parameters():
-        p.detach().zero_()
+        nn.init.zeros_(p)
     return module
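`zero_module` keeps its contract (zero every parameter of a module in place and hand the module back); the rewrite just swaps the manual `p.detach().zero_()` for `nn.init.zeros_(p)`, which performs the same in-place zeroing without tracking gradients. A quick illustrative check:

```python
import torch
from torch import nn

from neurosis.modules.diffusion.util import zero_module

proj = zero_module(nn.Conv2d(64, 64, kernel_size=1))
assert all(not p.detach().any() for p in proj.parameters())  # weight and bias are all zero
assert not zero_module(nn.Linear(8, 8))(torch.randn(2, 8)).any()  # zeroed layer maps everything to 0
```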


@@ -201,45 +202,36 @@ def mean_flat(tensor: Tensor) -> Tensor:
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
 
 
-class GroupNorm32(nn.GroupNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
-
-
-def normalization(channels: int) -> GroupNorm32:
-    """
-    Make a standard normalization layer.
-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(32, channels)
-
-
+@wraps(nn.modules.conv._ConvNd)
 def conv_nd(dims, *args, **kwargs) -> nn.Conv1d | nn.Conv2d | nn.Conv3d:
     """
     Create a 1D, 2D, or 3D convolution module.
     """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
+    match dims:
+        case 1:
+            return nn.Conv1d(*args, **kwargs)
+        case 2:
+            return nn.Conv2d(*args, **kwargs)
+        case 3:
+            return nn.Conv3d(*args, **kwargs)
+        case _:
+            raise ValueError(f"unsupported dimensions: {dims}")
 
 
+@wraps(nn.modules.pooling._AvgPoolNd)
 def avg_pool_nd(dims: int, *args, **kwargs) -> nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d:
     """
     Create a 1D, 2D, or 3D average pooling module.
     """
-    if dims == 1:
-        return nn.AvgPool1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.AvgPool2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.AvgPool3d(*args, **kwargs)
-    else:
-        raise ValueError(f"unsupported dimensions: {dims}")
+    match dims:
+        case 1:
+            return nn.AvgPool1d(*args, **kwargs)
+        case 2:
+            return nn.AvgPool2d(*args, **kwargs)
+        case 3:
+            return nn.AvgPool3d(*args, **kwargs)
+        case _:
+            raise ValueError(f"unsupported dimensions: {dims}")
 
 
 class AlphaBlender(nn.Module):
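The dimension dispatch in `conv_nd`/`avg_pool_nd` is unchanged in behaviour; the `if`/`elif` chains just become `match` statements, and the new `@wraps` decorators presumably exist so the helpers surface the wrapped constructors' metadata. Illustrative calls:

```python
from neurosis.modules.diffusion.util import avg_pool_nd, conv_nd

conv = conv_nd(2, 64, 128, 3, padding=1)  # nn.Conv2d(64, 128, kernel_size=3, padding=1)
pool = avg_pool_nd(1, kernel_size=2)      # nn.AvgPool1d(kernel_size=2)
try:
    conv_nd(4, 64, 128, 3)                # no Conv4d: falls through to the default case
except ValueError as e:
    print(e)                              # "unsupported dimensions: 4"
```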
24 changes: 3 additions & 21 deletions src/neurosis/modules/layers.py
@@ -2,27 +2,9 @@
 from torch import Tensor, nn
 
 
-class GroupNorm32(nn.GroupNorm):
-    def __init__(
-        self,
-        num_channels: int,
-        eps: float = 1e-5,
-        affine: bool = True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        super().__init__(
-            num_groups=32,
-            num_channels=num_channels,
-            eps=eps,
-            affine=affine,
-            device=device,
-            dtype=dtype,
-        )
-
-
-def normalization(channels: int) -> GroupNorm32:
-    return GroupNorm32(channels)
+class Normalize(nn.GroupNorm):
+    def __init__(self, in_channels: int, num_groups: int = 32):
+        super().__init__(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
 
 class ActNorm(nn.Module):
