Skip to content

Commit

Permalink
Initialize transformer unet block weights in right dtype at the start.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jun 15, 2023
1 parent 6253ec4 commit e21d9ad
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 44 deletions.
82 changes: 41 additions & 41 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,19 @@ def forward(self, x):


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim),
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)

self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
comfy.ops.Linear(inner_dim, dim_out)
comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
)

def forward(self, x):
Expand Down Expand Up @@ -147,20 +147,20 @@ def forward(self, x):


class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)

Expand Down Expand Up @@ -244,20 +244,20 @@ def forward(self, x, context=None, value=None, mask=None):


class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)

Expand Down Expand Up @@ -342,20 +342,20 @@ def forward(self, x, context=None, value=None, mask=None):
return self.to_out(r2)

class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

self.to_out = nn.Sequential(
comfy.ops.Linear(inner_dim, query_dim),
comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
nn.Dropout(dropout)
)

Expand Down Expand Up @@ -398,7 +398,7 @@ def forward(self, x, context=None, value=None, mask=None):

class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
super().__init__()
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads.")
Expand All @@ -408,11 +408,11 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
self.heads = heads
self.dim_head = dim_head

self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None

def forward(self, x, context=None, value=None, mask=None):
Expand Down Expand Up @@ -449,19 +449,19 @@ def forward(self, x, context=None, value=None, mask=None):
return self.to_out(out)

class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.heads = heads
self.dim_head = dim_head

self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None

def forward(self, x, context=None, value=None, mask=None):
Expand Down Expand Up @@ -507,17 +507,17 @@ def forward(self, x, context=None, value=None, mask=None):

class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False):
disable_self_attn=False, dtype=None):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype)
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
self.checkpoint = checkpoint

def forward(self, x, context=None, transformer_options={}):
Expand Down Expand Up @@ -588,7 +588,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True):
use_checkpoint=True, dtype=None):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
Expand All @@ -600,22 +600,22 @@ def __init__(self, in_channels, n_heads, d_head,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
padding=0, dtype=dtype)
else:
self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)

self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
for d in range(depth)]
)
if not use_linear:
self.proj_out = nn.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
padding=0)
padding=0, dtype=dtype)
else:
self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
self.use_linear = use_linear

def forward(self, x, context=None, transformer_options={}):
Expand Down
6 changes: 3 additions & 3 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def __init__(
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
Expand Down Expand Up @@ -688,7 +688,7 @@ def __init__(
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
),
ResBlock(
ch,
Expand Down Expand Up @@ -742,7 +742,7 @@ def __init__(
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype
)
)
if level and i == self.num_res_blocks[level]:
Expand Down

0 comments on commit e21d9ad

Please sign in to comment.