diff --git a/comfy/model_management.py b/comfy/model_management.py index e148408b889..21f7c71867e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -258,15 +258,11 @@ def load_model_gpu(model): if model is current_loaded_model: return unload_model() - try: - real_model = model.patch_model() - except Exception as e: - model.unpatch_model() - raise e torch_dev = model.load_device model.model_patches_to(torch_dev) model.model_patches_to(model.model_dtype()) + current_loaded_model = model if is_device_cpu(torch_dev): vram_set_state = VRAMState.DISABLED @@ -280,8 +276,7 @@ def load_model_gpu(model): if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM - current_loaded_model = model - + real_model = model.model if vram_set_state == VRAMState.DISABLED: pass elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: @@ -295,6 +290,14 @@ def load_model_gpu(model): accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) model_accelerated = True + + try: + real_model = model.patch_model() + except Exception as e: + model.unpatch_model() + unload_model() + raise e + return current_loaded_model def load_controlnet_gpu(control_models): diff --git a/comfy/sd.py b/comfy/sd.py index 125b15b7703..e30fae16c15 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -340,7 +340,7 @@ def patch_model(self): weight = model_sd[key] if key not in self.backup: - self.backup[key] = weight.clone() + self.backup[key] = weight.to(self.offload_device, copy=True) temp_weight = weight.to(torch.float32, copy=True) weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) @@ -367,15 +367,16 @@ def calculate_weight(self, patches, weight, key): else: weight += alpha * w1.type(weight.dtype).to(weight.device) elif len(v) == 4: #lora/locon - mat1 = v[0] - mat2 = v[1] + mat1 = v[0].float().to(weight.device) + mat2 = v[1].float().to(weight.device) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) - weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + mat3 = v[3].float().to(weight.device) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) + weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) elif len(v) == 8: #lokr w1 = v[0] w2 = v[1] @@ -389,20 +390,24 @@ def calculate_weight(self, patches, weight, key): if w1 is None: dim = w1_b.shape[0] w1 = torch.mm(w1_a.float(), w1_b.float()) + else: + w1 = w1.float().to(weight.device) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.float(), w2_b.float()) + w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + else: + w2 = w2.float().to(weight.device) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) if v[2] is not None and dim is not None: alpha *= v[2] / dim - weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) + weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) else: #loha w1a = v[0] w1b = v[1] @@ -413,13 +418,13 @@ def calculate_weight(self, patches, weight, key): if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float()) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float()) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) else: - m1 = torch.mm(w1a.float(), w1b.float()) - m2 = torch.mm(w2a.float(), w2b.float()) + m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) + m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) - weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) + weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) return weight def unpatch_model(self): diff --git a/comfy/utils.py b/comfy/utils.py index 956ac1773fd..d410e6af604 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -4,18 +4,20 @@ import comfy.checkpoint_pickle import safetensors.torch -def load_torch_file(ckpt, safe_load=False): +def load_torch_file(ckpt, safe_load=False, device=None): + if device is None: + device = torch.device("cpu") if ckpt.lower().endswith(".safetensors"): - sd = safetensors.torch.load_file(ckpt, device="cpu") + sd = safetensors.torch.load_file(ckpt, device=device.type) else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") safe_load = False if safe_load: - pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: - pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle) + pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: