diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628f..6ad6e763c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -477,7 +477,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() + params_to_clip = accelerator.unwrap_model(unet).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step()