Skip to content

Commit

Permalink
--gpu-only now keeps the VAE on the device.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 1, 2023
1 parent ce35d8c commit 1c1b0e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
9 changes: 9 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ def text_encoder_device():
else:
return torch.device("cpu")

def vae_device():
return get_torch_device()

def vae_offload_device():
if args.gpu_only or vram_state == VRAMState.SHARED:
return get_torch_device()
else:
return torch.device("cpu")

def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type
Expand Down
11 changes: 6 additions & 5 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,9 @@ def __init__(self, ckpt_path=None, device=None, config=None):
self.first_stage_model.load_state_dict(sd, strict=False)

if device is None:
device = model_management.get_torch_device()
device = model_management.vae_device()
self.device = device
self.offload_device = model_management.vae_offload_device()

def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
Expand Down Expand Up @@ -651,15 +652,15 @@ def decode(self, samples_in):
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)

self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples

def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)

def encode(self, pixel_samples):
Expand All @@ -679,15 +680,15 @@ def encode(self, pixel_samples):
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)

self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples

def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.cpu()
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples

def get_sd(self):
Expand Down

0 comments on commit 1c1b0e7

Please sign in to comment.