diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e69681611a4a..351295e938ff 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2286,6 +2286,50 @@ def unload_lora_weights(self): transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) transformer._transformer_norm_layers = None + if getattr(transformer, "_overwritten_params", None) is not None: + overwritten_params = transformer._overwritten_params + module_names = set() + + for param_name in overwritten_params: + if param_name.endswith(".weight"): + module_names.add(param_name.replace(".weight", "")) + + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear) and name in module_names: + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + current_param_weight = overwritten_params[f"{name}.weight"] + in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0] + with torch.device("meta"): + original_module = torch.nn.Linear( + in_features, + out_features, + bias=bias, + dtype=module_weight.dtype, + ) + + tmp_state_dict = {"weight": current_param_weight} + if module_bias is not None: + tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]}) + original_module.load_state_dict(tmp_state_dict, assign=True, strict=True) + setattr(parent_module, current_module_name, original_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(current_param_weight.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + @classmethod def _maybe_expand_transformer_param_shape_or_error_( cls, @@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_( # Expand transformer parameter shapes if they don't match lora has_param_with_shape_update = False + overwritten_params = {} + is_peft_loaded = getattr(transformer, "peft_config", None) is not None for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): @@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_( f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." ) + # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. + overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + return has_param_with_shape_update @classmethod diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index b22fbaaed69b..0861160de6aa 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -558,6 +558,72 @@ def test_load_regular_lora(self): self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2) self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3)) + def test_lora_unload_with_parameter_expanded_shapes(self): + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) + + logger = logging.get_logger("diffusers.loaders.lora_pipeline") + logger.setLevel(logging.DEBUG) + + # Change the transformer config to mimic a real use case. + num_channels_without_control = 4 + transformer = FluxTransformer2DModel.from_config( + components["transformer"].config, in_channels=num_channels_without_control + ).to(torch_device) + self.assertTrue( + transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}", + ) + + # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`. + components["transformer"] = transformer + pipe = FluxPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + control_image = inputs.pop("control_image") + original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] + + control_pipe = self.pipeline_class(**components) + out_features, in_features = control_pipe.transformer.x_embedder.weight.shape + rank = 4 + + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) + lora_state_dict = { + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, + } + with CaptureLogger(logger) as cap_logger: + control_pipe.load_lora_weights(lora_state_dict, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser") + + inputs["control_image"] = control_image + lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features) + self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features) + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) + + control_pipe.unload_lora_weights() + self.assertTrue( + control_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}", + ) + loaded_pipe = FluxPipeline.from_pipe(control_pipe) + self.assertTrue( + loaded_pipe.transformer.config.in_channels == num_channels_without_control, + f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}", + ) + inputs.pop("control_image") + unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0] + + self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4)) + self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4)) + self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features) + self.assertTrue(pipe.transformer.config.in_channels == in_features) + @unittest.skip("Not supported in Flux.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass