
[LoRA] feat: support unload_lora_weights() for Flux Control. #10206

Merged · 13 commits · Dec 25, 2024
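As a rough sketch of the flow this PR targets (the identifiers below are illustrative, not part of the PR; the exact behavior is what the new test at the bottom of this diff verifies): loading a Flux Control LoRA expands `transformer.x_embedder` and `config.in_channels`, and `unload_lora_weights()` now restores the original shapes and config.

```python
# Illustrative sketch only -- the checkpoint and LoRA identifiers are placeholders.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Loading a Flux Control LoRA expands x_embedder (e.g. in_features 64 -> 128)
# and updates transformer.config.in_channels accordingly.
pipe.load_lora_weights("<flux-control-lora-repo-or-path>")

# New in this PR: unloading restores the original x_embedder shape and config.
pipe.unload_lora_weights()
assert pipe.transformer.config.in_channels == 64
```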
56 changes: 56 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
@@ -2286,6 +2286,50 @@ def unload_lora_weights(self):
            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
            transformer._transformer_norm_layers = None

if getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()

for param_name in overwritten_params:
if param_name.endswith(".weight"):
module_names.add(param_name.replace(".weight", ""))

for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear) and name in module_names:
module_weight = module.weight.data
module_bias = module.bias.data if module.bias is not None else None
bias = module_bias is not None

parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)

current_param_weight = overwritten_params[f"{name}.weight"]
in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
with torch.device("meta"):
Member Author (@sayakpaul):

Since we already pin the torch version, this is safe enough.
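(For readers, a minimal standalone sketch of the pattern this hunk relies on; as far as I know, both `torch.device(...)` as a context manager and `load_state_dict(..., assign=True)` are only available in reasonably recent torch releases, which is why the pinned version matters.)

```python
import torch

# Construct under the meta device: no real storage is allocated for the parameters.
with torch.device("meta"):
    layer = torch.nn.Linear(8, 4, bias=False)
assert layer.weight.is_meta

# assign=True replaces the data-less meta parameters with the provided tensors
# outright, instead of trying to copy data into them.
layer.load_state_dict({"weight": torch.randn(4, 8)}, assign=True, strict=True)
assert not layer.weight.is_meta
```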


                        original_module = torch.nn.Linear(
                            in_features,
                            out_features,
                            bias=bias,
                            dtype=module_weight.dtype,
                        )

                    tmp_state_dict = {"weight": current_param_weight}
                    if module_bias is not None:
                        tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
                    original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
                    setattr(parent_module, current_module_name, original_module)
Comment on lines +2309 to +2320

Member Author (@sayakpaul):

@a-r-r-o-w thanks for flagging the device assignment while initializing original_module. The explicit device argument takes priority, so original_module was not getting initialized on "meta", rendering the previous copy_() ops ineffective.

LMK what you think about the current changes (I have run the corresponding tests on a GPU and they pass).

@DN6 LMK your comments here too.
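To illustrate the pitfall for other reviewers (a standalone sketch, not part of the diff): an explicit `device=` argument to the factory function wins over the ambient `torch.device("meta")` context, so a module constructed with an explicit device inside the context never ends up on "meta" — which is what the comment above describes.

```python
import torch

with torch.device("meta"):
    follows_context = torch.nn.Linear(8, 4)                # no device kwarg -> "meta"
    explicit_device = torch.nn.Linear(8, 4, device="cpu")  # explicit kwarg wins -> "cpu"

print(follows_context.weight.is_meta)   # True
print(explicit_device.weight.device)    # cpu
```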


                    del tmp_state_dict

                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
                        new_value = int(current_param_weight.shape[1])
                        old_value = getattr(transformer.config, attribute_name)
                        setattr(transformer.config, attribute_name, new_value)
                        logger.info(
                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                        )

    @classmethod
    def _maybe_expand_transformer_param_shape_or_error_(
        cls,
@@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_(

        # Expand transformer parameter shapes if they don't match lora
        has_param_with_shape_update = False
        overwritten_params = {}

        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
        for name, module in transformer.named_modules():
            if isinstance(module, torch.nn.Linear):
@@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
)

# For `unload_lora_weights()`.
# TODO: this could lead to more memory overhead if the number of overwritten params
# are large. Should be revisited later and tackled through a `discard_original_layers` arg.
overwritten_params[f"{current_module_name}.weight"] = module_weight
Member:

I think this would have a small but significant memory overhead. For inference-only purposes with LoRAs, maybe this could be made opt-out if we know we will never call unload_lora_weights(). Not a blocker, though, and it can be tackled in a different PR, but LMK your thoughts.

Member Author (@sayakpaul), Dec 17, 2024:

Yeah, this could be tackled with discard_original_layers. For now, I have added a note as a comment about it.
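For a rough sense of the overhead being discussed (a back-of-the-envelope sketch; the helper name is made up and the Flux numbers are approximate): `_overwritten_params` keeps the pre-expansion tensors alive, so the cost can be summed directly.

```python
import torch

def overwritten_params_nbytes(transformer: torch.nn.Module) -> int:
    """Bytes kept alive by the original (pre-expansion) params stored for unloading."""
    params = getattr(transformer, "_overwritten_params", None) or {}
    return sum(t.numel() * t.element_size() for t in params.values())

# For Flux Control, typically only x_embedder is expanded (in_features 64 -> 128),
# so the retained copy is roughly 3072 * 64 * 2 bytes (~0.4 MB) in bf16 -- small on
# its own, but it grows with every expanded layer, hence the `discard_original_layers`
# idea above.
```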

                if module_bias is not None:
                    overwritten_params[f"{current_module_name}.bias"] = module_bias

        if len(overwritten_params) > 0:
            transformer._overwritten_params = overwritten_params

        return has_param_with_shape_update

    @classmethod
66 changes: 66 additions & 0 deletions tests/lora/test_lora_layers_flux.py
@@ -558,6 +558,72 @@ def test_load_regular_lora(self):
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
        self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))

    def test_lora_unload_with_parameter_expanded_shapes(self):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        # Change the transformer config to mimic a real use case.
        num_channels_without_control = 4
        transformer = FluxTransformer2DModel.from_config(
            components["transformer"].config, in_channels=num_channels_without_control
        ).to(torch_device)
        self.assertTrue(
            transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
        )

        # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
        components["transformer"] = transformer
        pipe = FluxPipeline(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        control_image = inputs.pop("control_image")
        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        control_pipe = self.pipeline_class(**components)
        out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
        rank = 4

        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
        lora_state_dict = {
            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
        }
        with CaptureLogger(logger) as cap_logger:
            control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

        inputs["control_image"] = control_image
        lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
        self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))

        control_pipe.unload_lora_weights()
        self.assertTrue(
            control_pipe.transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
        )
        loaded_pipe = FluxPipeline.from_pipe(control_pipe)
        self.assertTrue(
            loaded_pipe.transformer.config.in_channels == num_channels_without_control,
            f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
        )
        inputs.pop("control_image")
        unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
        self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
        self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
        self.assertTrue(pipe.transformer.config.in_channels == in_features)

    @unittest.skip("Not supported in Flux.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass