All the unet weights should now be initialized with the right dtype.
comfyanonymous committed Jun 15, 2023
1 parent cf3974c commit ae43f09
Showing 3 changed files with 29 additions and 23 deletions.
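Every change below follows the same pattern: instead of building a layer with the default parameter dtype and casting the weights afterwards, the desired dtype is passed to the layer's constructor so the weights are created in that dtype from the start. A minimal sketch of the difference, using a plain torch.nn.Linear and float16 purely as an illustration (the commit itself does not hard-code any particular dtype):

import torch
import torch.nn as nn

# Built in the default dtype (normally float32), then cast afterwards;
# a float32 copy of the weights is allocated first.
cast_later = nn.Linear(320, 1280).half()

# Built directly in the requested dtype; no intermediate float32 weights.
built_directly = nn.Linear(320, 1280, dtype=torch.float16)

assert cast_later.weight.dtype == built_directly.weight.dtype == torch.float16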
comfy/ldm/modules/attention.py (6 additions, 6 deletions)
@@ -51,9 +51,9 @@ def init_(tensor):
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    def __init__(self, dim_in, dim_out, dtype=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -68,7 +68,7 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None)
         project_in = nn.Sequential(
             comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
 
         self.net = nn.Sequential(
             project_in,
@@ -89,8 +89,8 @@ def zero_module(module):
     return module
 
 
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
 
 
 class SpatialSelfAttention(nn.Module):
@@ -594,7 +594,7 @@ def __init__(self, in_channels, n_heads, d_head,
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = Normalize(in_channels, dtype=dtype)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
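The attention.py changes just forward the new dtype argument to the layers GEGLU and Normalize construct. A rough stand-alone equivalent of the updated GEGLU, with torch.nn.Linear standing in for ComfyUI's comfy.ops.Linear wrapper (a stand-in chosen here for illustration only):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    # Gated-GELU projection; dtype is forwarded to the Linear it creates.
    def __init__(self, dim_in, dim_out, dtype=None):
        super().__init__()
        # dtype=None keeps the old behaviour (default dtype).
        self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

layer = GEGLU(320, 1280, dtype=torch.float16)
print(layer.proj.weight.dtype)  # torch.float16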
comfy/ldm/modules/diffusionmodules/openaimodel.py (21 additions, 15 deletions)
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
 
     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
             )
         else:
             assert self.channels == self.out_channels
@@ -220,31 +220,31 @@ def __init__(
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            normalization(channels),
+            normalization(channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
 
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
         else:
             self.h_upd = self.x_upd = nn.Identity()
 
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            normalization(self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -604,6 +604,7 @@ def __init__(
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = mult * model_channels
@@ -651,10 +652,11 @@ def __init__(
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
                         )
                     )
                 )
@@ -679,6 +681,7 @@ def __init__(
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
             AttentionBlock(
                 ch,
@@ -698,6 +701,7 @@ def __init__(
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
         )
         self._feature_size += ch
@@ -715,6 +719,7 @@ def __init__(
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = model_channels * mult
@@ -758,18 +763,19 @@ def __init__(
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch),
+            normalization(ch, dtype=self.dtype),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
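In openaimodel.py the argument is threaded one level further: UNetModel now passes its own self.dtype into every ResBlock, Upsample and Downsample it builds, and those blocks hand it on to the convolutions and normalization layers they create. A simplified sketch of that pattern with toy classes (not the real UNetModel API):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, channels, dtype=None):
        super().__init__()
        # dtype goes straight to the conv, so its weights start out in that dtype.
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, dtype=dtype)

class ToyUNet(nn.Module):
    def __init__(self, channels, dtype=None):
        super().__init__()
        self.dtype = dtype
        # Every child block receives self.dtype, mirroring dtype=self.dtype in the diff above.
        self.blocks = nn.ModuleList([ToyBlock(channels, dtype=self.dtype) for _ in range(3)])

model = ToyUNet(64, dtype=torch.float16)
assert all(p.dtype == torch.float16 for p in model.parameters())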
comfy/ldm/modules/diffusionmodules/util.py (2 additions, 2 deletions)
@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
 
 
-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)
 
 
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
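The normalization() helper in util.py gets the same treatment, so the affine weight and bias of the group norm are also created in the requested dtype. Assuming GroupNorm32 accepts the same constructor arguments as torch.nn.GroupNorm (used here as a stand-in), the effect is roughly:

import torch

norm = torch.nn.GroupNorm(32, 320, dtype=torch.float16)
print(norm.weight.dtype, norm.bias.dtype)  # torch.float16 torch.float16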
