StableDiffusionControlNetInpaintPipeline training script #6859
Replies: 2 comments 2 replies
-
Could you show how you're initializing things?
-
Hi @sayakpaul, thank you very much for your answer! Sure, here is how I am initializing things. I am using this model in my args:
Then, in the training script (as in the example in diffusers):

```python
# Weights of ip-adapter
state_dict_path = os.path.join(pretrained_path, ckpt_folder, "ip_adapter.bin")
...

# in the training loop, I do:
# Load the tokenizer
if args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", revision=args.revision, variant=args.variant
)

if args.controlnet_model_name_or_path:
    logger.info("Loading existing controlnet weights")
    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
else:
    logger.info("Initializing controlnet weights from unet")
    controlnet = ControlNetModel.from_unet(unet)

state_dict = torch.load(state_dict_path)
unet._load_ip_adapter_weights(state_dict)
```

I am loading a ControlNetModel from a previous training run with a model that was not for inpainting (SD 1.5, 4 channels instead of 9). The training loop seems to work fine. Then, when I log to wandb in the validation module, I do:

```python
# In the training loop:
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
    image_logs = log_validation(
        vae,
        text_encoder,
        tokenizer,
        unet,
        controlnet,
        args,
        accelerator,
        weight_dtype,
        global_step,
    )
```
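As a side note: before `unet._load_ip_adapter_weights(state_dict)` is called above, a small sanity check can catch a malformed checkpoint early. The checkpoints in `h94/IP-Adapter` group their weights under the top-level keys `image_proj` and `ip_adapter`; this helper is hypothetical, not part of the training script:

```python
def check_ip_adapter_state_dict(state_dict):
    """Verify that a loaded ip_adapter.bin has the two weight groups an
    IP-Adapter checkpoint is expected to contain (hypothetical helper;
    key names assume the h94/IP-Adapter checkpoint layout)."""
    expected = {"image_proj", "ip_adapter"}
    missing = expected - set(state_dict)
    if missing:
        raise KeyError(f"ip_adapter.bin is missing weight groups: {sorted(missing)}")
    return True
```

Usage would be `check_ip_adapter_state_dict(torch.load(state_dict_path))` right after loading, so a wrong file fails with a clear message instead of deep inside `_load_ip_adapter_weights`.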
```python
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation... ")
    controlnet = accelerator.unwrap_model(controlnet)

    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch.float16,
    ).to("cuda")

    pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        safety_checker=None,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
        image_encoder=image_encoder,
    )
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.set_ip_adapter_scale(0.85)
    logger.info("Loaded ip_adapter weights")
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)
```

And it is here that I get the warning:
-
Hi Diffusers,
I am training a CN model using the script train_controlnet.py from the original repo. I am doing it as an inpainting task (not text-to-image), with an ip-adapter also loaded. I am getting this warning:
{'controlnet', 'image_encoder'} was not found in config. Values will be initialized to default values.
after initializing my StableDiffusionControlNetInpaintPipeline. Is this expected? I had to initialize my ControlNet from a ckpt that has 4 channels rather than the 9 of my inpainting model, so I am not sure this is the right way to do it.
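My understanding (an assumption on my part, not confirmed against the diffusers source) is that the message is just set arithmetic: the pipeline class expects `controlnet` and `image_encoder` components, but the base checkpoint's model_index.json does not list them, so they are reported as missing even though I pass them explicitly to `from_pretrained`. A pure-Python sketch of that comparison, with the config keys assumed from a typical SD 1.5 model_index.json:

```python
def missing_from_config(expected_components, config_keys):
    """Components the pipeline class expects but that are absent from the
    saved model_index.json (illustrative; the names below are assumptions)."""
    return set(expected_components) - set(config_keys)

# Entries a plain SD 1.5 checkpoint's model_index.json roughly contains:
config_keys = [
    "vae", "text_encoder", "tokenizer", "unet", "scheduler",
    "safety_checker", "feature_extractor",
]
# The ControlNet inpaint pipeline additionally expects these two:
expected = config_keys + ["controlnet", "image_encoder"]

missing = missing_from_config(expected, config_keys)
# missing == {"controlnet", "image_encoder"}, matching the set in the warning
```

If that reading is right, the warning is harmless whenever the components are supplied explicitly, since the passed objects are used instead of defaults.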