From bf3f27177529cd7b82ba4776cf99ed90d090081e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 19:17:03 -0400 Subject: [PATCH] Add some nodes for basic model merging. --- comfy_extras/nodes_model_merging.py | 55 +++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 56 insertions(+) create mode 100644 comfy_extras/nodes_model_merging.py diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py new file mode 100644 index 00000000000..daf4b09baea --- /dev/null +++ b/comfy_extras/nodes_model_merging.py @@ -0,0 +1,55 @@ + + +class ModelMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, ratio): + m = model1.clone() + sd = model2.model_state_dict() + for k in sd: + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +class ModelMergeBlocks: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, **kwargs): + m = model1.clone() + sd = model2.model_state_dict() + default_ratio = next(iter(kwargs.values())) + + for k in sd: + ratio = default_ratio + k_unet = k[len("diffusion_model."):] + + for arg in kwargs: + if k_unet.startswith(arg): + ratio = kwargs[arg] + + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelMergeSimple": ModelMergeSimple, + "ModelMergeBlocks": ModelMergeBlocks +} diff --git a/nodes.py b/nodes.py index cbb7d69ea0c..396abe30868 100644 --- a/nodes.py +++ b/nodes.py @@ -1459,4 +1459,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) load_custom_nodes()