Skip to content

Commit

Permalink
Lower LoRA RAM usage when in normal VRAM mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 16, 2023
1 parent 490771b commit 5f57362
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
20 changes: 10 additions & 10 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,9 @@ def unload_model():
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
model_accelerated = False


current_loaded_model.unpatch_model()
current_loaded_model.model.to(current_loaded_model.offload_device)
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != VRAMState.HIGH_VRAM:
soft_empty_cache()
Expand Down Expand Up @@ -282,14 +281,6 @@ def load_model_gpu(model):
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
model_accelerated = False
real_model.to(torch_dev)
else:
if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})

accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True

try:
real_model = model.patch_model()
Expand All @@ -298,6 +289,15 @@ def load_model_gpu(model):
unload_model()
raise e

if vram_set_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True
elif vram_set_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
model_accelerated = True

return current_loaded_model

def load_controlnet_gpu(control_models):
Expand Down
12 changes: 9 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,17 @@ def calculate_weight(self, patches, weight, key):
return weight

def unpatch_model(self):
model_sd = self.model_state_dict()
keys = list(self.backup.keys())
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev

for k in keys:
model_sd[k][:] = self.backup[k]
del self.backup[k]
set_attr(self.model, k, self.backup[k])

self.backup = {}

Expand Down

0 comments on commit 5f57362

Please sign in to comment.