Skip to content

Commit

Permalink
Add CheckpointSave node to save checkpoints.
Browse files Browse the repository at this point in the history
The created checkpoints contain workflow metadata that can be loaded by
dragging them on top of the UI or loading them with the "Load" button.

Checkpoints will be saved in fp16 or fp32 depending on the format ComfyUI
is using for inference on your hardware. To force fp32 use: --force-fp32

Anything that patches the model weights like merging or loras will be
saved.

The output directory is currently set to: output/checkpoints but that might
change in the future.
  • Loading branch information
comfyanonymous committed Jun 26, 2023
1 parent b72a7a8 commit 9b93b92
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 13 deletions.
4 changes: 3 additions & 1 deletion comfy/diffusers_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,13 @@ def convert_vae_state_dict(vae_state_dict):
code2idx = {"q": 0, "k": 1, "v": 2}


def convert_text_enc_state_dict_v20(text_enc_dict):
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
capture_qkv_bias = {}
for k, v in text_enc_dict.items():
if not k.startswith(prefix):
continue
if (
k.endswith(".self_attn.q_proj.weight")
or k.endswith(".self_attn.k_proj.weight")
Expand Down
12 changes: 12 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from . import utils

class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
super().__init__()

unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
Expand Down Expand Up @@ -83,6 +85,16 @@ def process_latent_in(self, latent):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)

def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}


class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):
Expand Down
32 changes: 29 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,11 @@ def encode_from_tokens(self, tokens, return_pooled=False):
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
try:
self.patcher.patch_model()
self.patch_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
self.patcher.unpatch_model()
self.unpatch_model()
except Exception as e:
self.patcher.unpatch_model()
self.unpatch_model()
raise e

cond_out = cond
Expand All @@ -564,6 +564,15 @@ def encode(self, text):
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)

def get_sd(self):
return self.cond_stage_model.state_dict()

def patch_model(self):
self.patcher.patch_model()

def unpatch_model(self):
self.patcher.unpatch_model()

class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
if config is None:
Expand Down Expand Up @@ -665,6 +674,10 @@ def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
self.first_stage_model = self.first_stage_model.cpu()
return samples

def get_sd(self):
return self.first_stage_model.state_dict()


def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
Expand Down Expand Up @@ -1135,3 +1148,16 @@ class WeightsLoader(torch.nn.Module):
print("left over keys:", left_over)

return (ModelPatcher(model), clip, vae, clipvision)

def save_checkpoint(output_path, model, clip, vae, metadata=None):
try:
model.patch_model()
clip.patch_model()
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
utils.save_torch_file(sd, output_path, metadata=metadata)
model.unpatch_model()
clip.unpatch_model()
except Exception as e:
model.unpatch_model()
clip.unpatch_model()
raise e
29 changes: 29 additions & 0 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

class SD15(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
Expand Down Expand Up @@ -63,6 +65,13 @@ def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict

def clip_target(self):
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

Expand Down Expand Up @@ -113,6 +122,13 @@ def process_clip_state_dict(self, state_dict):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g

def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

Expand Down Expand Up @@ -142,6 +158,19 @@ def process_clip_state_dict(self, state_dict):
state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k]

replace_prefix["clip_g"] = "conditioner.embedders.1.model"
replace_prefix["clip_l"] = "conditioner.embedders.0"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g

def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

Expand Down
12 changes: 12 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ def get_model(self, state_dict):
def process_clip_state_dict(self, state_dict):
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.diffusion_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return state_dict_prefix_replace(state_dict, replace_prefix)

14 changes: 13 additions & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import math
import struct
import comfy.checkpoint_pickle
import safetensors.torch

def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
if safe_load:
Expand All @@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
sd = pl_sd
return sd

def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
safetensors.torch.save_file(sd, ckpt, metadata=metadata)
else:
safetensors.torch.save_file(sd, ckpt)

def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
Expand Down Expand Up @@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
return sd

def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
state_dict[k] = state_dict[k].to(dtype)
return state_dict

def safetensors_header(safetensors_path, max_size=100*1024*1024):
with open(safetensors_path, "rb") as f:
header = f.read(8)
Expand Down
44 changes: 42 additions & 2 deletions comfy_extras/nodes_model_merging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

import comfy.sd
import comfy.utils
import folder_paths
import json
import os

class ModelMergeSimple:
@classmethod
Expand Down Expand Up @@ -49,7 +53,43 @@ def merge(self, model1, model2, **kwargs):
m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
return (m, )

class CheckpointSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "_for_testing/model_merging"

def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)

metadata = {"prompt": prompt_info}
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}


NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks
"ModelMergeBlocks": ModelMergeBlocks,
"CheckpointSave": CheckpointSave,
}
3 changes: 1 addition & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,7 @@ def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=No
output["latent_tensor"] = samples["samples"]
output["latent_format_version_0"] = torch.tensor([])

safetensors.torch.save_file(output, file, metadata=metadata)

comfy.utils.save_torch_file(output, file, metadata=metadata)
return {}


Expand Down
1 change: 1 addition & 0 deletions notebooks/comfyui_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
"\n",
"\n",
"# ESRGAN upscale model\n",
"#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion web/scripts/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,7 @@ export class ComfyApp {
this.loadGraphData(JSON.parse(reader.result));
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent")) {
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
Expand Down
5 changes: 3 additions & 2 deletions web/scripts/pnginfo.js
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
const dataView = new DataView(safetensorsData.buffer);
let header_size = dataView.getUint32(0, true);
let offset = 8;
let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
r(header.__metadata__);
};

reader.readAsArrayBuffer(file);
var slice = file.slice(0, 1024 * 1024 * 4);
reader.readAsArrayBuffer(slice);
});
}

Expand Down
2 changes: 1 addition & 1 deletion web/scripts/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent",
accept: ".json,image/png,.latent,.safetensors",
style: {display: "none"},
parent: document.body,
onchange: () => {
Expand Down

0 comments on commit 9b93b92

Please sign in to comment.