diff --git a/comfy/model_management.py b/comfy/model_management.py index ecbcabb0a77..e44c9e8a594 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -349,6 +349,15 @@ def text_encoder_device(): else: return torch.device("cpu") +def vae_device(): + return get_torch_device() + +def vae_offload_device(): + if args.gpu_only or vram_state == VRAMState.SHARED: + return get_torch_device() + else: + return torch.device("cpu") + def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type diff --git a/comfy/sd.py b/comfy/sd.py index 08d68c5f89b..3d79c7c04fb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -605,8 +605,9 @@ def __init__(self, ckpt_path=None, device=None, config=None): self.first_stage_model.load_state_dict(sd, strict=False) if device is None: - device = model_management.get_torch_device() + device = model_management.vae_device() self.device = device + self.offload_device = model_management.vae_offload_device() def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) @@ -651,7 +652,7 @@ def decode(self, samples_in): print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples @@ -659,7 +660,7 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): @@ -679,7 +680,7 @@ def encode(self, pixel_samples): print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -687,7 +688,7 @@ def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) - self.first_stage_model = self.first_stage_model.cpu() + self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def get_sd(self):