Skip to content

Commit

Permalink
Don't unload/reload model from CPU uselessly.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Feb 8, 2023
1 parent e3e6594 commit a84cd0d
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 36 deletions.
26 changes: 26 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@


current_loaded_model = None


def unload_model():
global current_loaded_model
if current_loaded_model is not None:
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None


def load_model_gpu(model):
global current_loaded_model
if model is current_loaded_model:
return
unload_model()
try:
real_model = model.patch_model()
except Exception as e:
model.unpatch_model()
raise e
current_loaded_model = model
real_model.cuda()
return current_loaded_model
3 changes: 3 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sd1_clip
import sd2_clip
import model_management
from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
Expand Down Expand Up @@ -304,6 +305,7 @@ def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=N
self.device = device

def decode(self, samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
samples = samples.to(self.device)
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
Expand All @@ -313,6 +315,7 @@ def decode(self, samples):
return pixel_samples

def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
Expand Down
69 changes: 33 additions & 36 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import comfy.samplers
import comfy.sd
import model_management

supported_ckpt_extensions = ['.ckpt']
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
Expand Down Expand Up @@ -353,43 +354,39 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")

real_model = None
try:
if device != "cpu":
model_management.load_model_gpu(model)
real_model = model.model
else:
#TODO: cpu support
real_model = model.patch_model()
real_model.to(device)
noise = noise.to(device)
latent_image = latent_image.to(device)

positive_copy = []
negative_copy = []

for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]

if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
else:
#other samplers
pass

samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
samples = samples.cpu()
real_model.cpu()
model.unpatch_model()
except Exception as e:
if real_model is not None:
real_model.cpu()
model.unpatch_model()
raise e
noise = noise.to(device)
latent_image = latent_image.to(device)

positive_copy = []
negative_copy = []

for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]

if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
else:
#other samplers
pass

samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
samples = samples.cpu()

return (samples, )

Expand Down

0 comments on commit a84cd0d

Please sign in to comment.