Skip to content

Commit

Permalink
Merge branch 'comfyanonymous:master' into socketrework
Browse files Browse the repository at this point in the history
  • Loading branch information
pythongosssss authored Feb 25, 2023
2 parents 9f391ab + 1144ed5 commit a9c5784
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
- [Area Composition](https://comfyanonymous.github.io/ComfyUI_examples/area_composition/)
- [Inpainting](https://comfyanonymous.github.io/ComfyUI_examples/inpaint/) with both regular and inpainting models.
- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/)
- [ControlNet](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) and T2I-Adapter
- Starts up very fast.
- Works fully offline: will never download anything.

Expand Down
3 changes: 0 additions & 3 deletions comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import einops
import torch
import torch as th
import torch.nn as nn
Expand All @@ -13,8 +12,6 @@
timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
Expand Down
16 changes: 11 additions & 5 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,17 +774,23 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwarg
emb = emb + self.label_emb(y)

h = x.type(self.dtype)
for module in self.input_blocks:
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()

for module in self.output_blocks:
hsp = hs.pop()
if control is not None:
hsp += control.pop()
if control is not None and 'output' in control and len(control['output']) > 0:
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
h = th.cat([h, hsp], dim=1)
del hsp
h = module(h, emb, context)
Expand Down
139 changes: 135 additions & 4 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
from .cldm import cldm
from .t2i_adapter import adapter

from . import utils

Expand Down Expand Up @@ -318,6 +319,37 @@ def decode(self, samples):
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples

def decode_tiled(self, samples):
tile_x = tile_y = 64
overlap = 8
model_management.unload_model()
output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
self.first_stage_model = self.first_stage_model.to(self.device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:,:,y:y+tile_y,x:x+tile_x]

pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
ps = pixel_samples.cpu()
mask = torch.ones_like(ps)
feather = overlap * 8
for t in range(feather):
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask
out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask

output[b:b+1] = out/out_div
self.first_stage_model = self.first_stage_model.cpu()
return output.movedim(1,-1)

def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
Expand Down Expand Up @@ -357,18 +389,28 @@ def get_control(self, x_noisy, t, cond_txt):
self.control_model = model_management.load_if_low_vram(self.control_model)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
self.control_model = model_management.unload_if_low_vram(self.control_model)
out = []
out = {'middle':[], 'output': []}
autocast_enabled = torch.is_autocast_enabled()

for i in range(len(control)):
if i == (len(control) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control[i]
x *= self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)

if control_prev is not None:
x += control_prev[i]
out.append(x)
if control_prev is not None and key in control_prev:
prev = control_prev[key][index]
if prev is not None:
x += prev
out[key].append(x)
if control_prev is not None and 'input' in control_prev:
out['input'] = control_prev['input']
return out

def set_cond_hint(self, cond_hint, strength=1.0):
Expand Down Expand Up @@ -463,6 +505,95 @@ class WeightsLoader(torch.nn.Module):
control = ControlNet(control_model)
return control

class T2IAdapter:
def __init__(self, t2i_model, channels_in, device="cuda"):
self.t2i_model = t2i_model
self.channels_in = channels_in
self.strength = 1.0
self.device = device
self.previous_controlnet = None
self.control_input = None
self.cond_hint_original = None
self.cond_hint = None

def get_control(self, x_noisy, t, cond_txt):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt)

if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint)
self.t2i_model.cpu()

output_dtype = x_noisy.dtype
out = {'input':[]}

for i in range(len(self.control_input)):
key = 'input'
x = self.control_input[i] * self.strength
if x.dtype != output_dtype and not autocast_enabled:
x = x.to(output_dtype)

if control_prev is not None and key in control_prev:
index = len(control_prev[key]) - i * 3 - 3
prev = control_prev[key][index]
if prev is not None:
x += prev
out[key].insert(0, None)
out[key].insert(0, None)
out[key].insert(0, x)

if control_prev is not None and 'input' in control_prev:
for i in range(len(out['input'])):
if out['input'][i] is None:
out['input'][i] = control_prev['input'][i]
if control_prev is not None and 'middle' in control_prev:
out['middle'] = control_prev['middle']
if control_prev is not None and 'output' in control_prev:
out['output'] = control_prev['output']
return out

def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self

def set_previous_controlnet(self, controlnet):
self.previous_controlnet = controlnet
return self

def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
return c

def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None

def get_control_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_control_models()
return out

def load_t2i_adapter(ckpt_path, model=None):
t2i_data = load_torch_file(ckpt_path)
cin = t2i_data['conv_in.weight'].shape[1]
model_ad = adapter.Adapter(cin=cin, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
model_ad.load_state_dict(t2i_data)
return T2IAdapter(model_ad, cin // 64)

def load_clip(ckpt_path, embedding_directory=None):
clip_data = load_torch_file(ckpt_path)
Expand Down
125 changes: 125 additions & 0 deletions comfy/t2i_adapter/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#taken from https://github.com/TencentARC/T2I-Adapter

import torch
import torch.nn as nn
import torch.nn.functional as F
from ldm.modules.attention import SpatialTransformer, BasicTransformerBlock

def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")

def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")

class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)


class ResnetBlock(nn.Module):
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
super().__init__()
ps = ksize//2
if in_c != out_c or sk==False:
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
# print('n_in')
self.in_conv = None
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
if sk==False:
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
else:
self.skep = None

self.down = down
if self.down == True:
self.down_opt = Downsample(in_c, use_conv=use_conv)

def forward(self, x):
if self.down == True:
x = self.down_opt(x)
if self.in_conv is not None: # edit
x = self.in_conv(x)

h = self.block1(x)
h = self.act(h)
h = self.block2(h)
if self.skep is not None:
return h + self.skep(x)
else:
return h + x


class Adapter(nn.Module):
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
super(Adapter, self).__init__()
self.unshuffle = nn.PixelUnshuffle(8)
self.channels = channels
self.nums_rb = nums_rb
self.body = []
for i in range(len(channels)):
for j in range(nums_rb):
if (i!=0) and (j==0):
self.body.append(ResnetBlock(channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
else:
self.body.append(ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
self.body = nn.ModuleList(self.body)
self.conv_in = nn.Conv2d(cin,channels[0], 3, 1, 1)

def forward(self, x):
# unshuffle
x = self.unshuffle(x)
# extract features
features = []
x = self.conv_in(x)
for i in range(len(self.channels)):
for j in range(self.nums_rb):
idx = i*self.nums_rb +j
x = self.body[idx](x)
features.append(x)

return features
Empty file.
Loading

0 comments on commit a9c5784

Please sign in to comment.