[LoRA] feat: support unload_lora_weights() for Flux Control. #10206
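In short: loading a Flux Control LoRA expands some of the transformer's Linear layers (and the matching config attributes, e.g. in_channels), and this PR records the original parameters so that unload_lora_weights() can restore both the layers and the config. A minimal usage sketch, assuming the FluxControlPipeline API and the public Canny Control LoRA; the repo ids are illustrative and not part of this diff:

import torch
from diffusers import FluxControlPipeline

# Illustrative checkpoints; any Flux Control LoRA follows the same flow.
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
print(pipe.transformer.config.in_channels)  # expanded to accept the control latents

pipe.unload_lora_weights()
print(pipe.transformer.config.in_channels)  # restored to the original value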
src/diffusers/loaders/lora_pipeline.py

@@ -2286,6 +2286,50 @@ def unload_lora_weights(self):
            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
            transformer._transformer_norm_layers = None

        if getattr(transformer, "_overwritten_params", None) is not None:
            overwritten_params = transformer._overwritten_params
            module_names = set()

            for param_name in overwritten_params:
                if param_name.endswith(".weight"):
                    module_names.add(param_name.replace(".weight", ""))

            for name, module in transformer.named_modules():
                if isinstance(module, torch.nn.Linear) and name in module_names:
                    module_weight = module.weight.data
                    module_bias = module.bias.data if module.bias is not None else None
                    bias = module_bias is not None

                    parent_module_name, _, current_module_name = name.rpartition(".")
                    parent_module = transformer.get_submodule(parent_module_name)

                    current_param_weight = overwritten_params[f"{name}.weight"]
                    in_features, out_features = current_param_weight.shape[1], current_param_weight.shape[0]
                    with torch.device("meta"):
                        original_module = torch.nn.Linear(
                            in_features,
                            out_features,
                            bias=bias,
                            dtype=module_weight.dtype,
                        )

                    tmp_state_dict = {"weight": current_param_weight}
                    if module_bias is not None:
                        tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
                    original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
                    setattr(parent_module, current_module_name, original_module)
Comment on lines +2309 to +2320:
@a-r-r-o-w thanks for flagging. LMK what you think about the current changes (I have run the corresponding tests on a GPU and they pass). @DN6 LMK your comments here too.
                    del tmp_state_dict

                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
                        new_value = int(current_param_weight.shape[1])
                        old_value = getattr(transformer.config, attribute_name)
                        setattr(transformer.config, attribute_name, new_value)
                        logger.info(
                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                        )

    @classmethod
    def _maybe_expand_transformer_param_shape_or_error_(
        cls,
@@ -2312,6 +2356,8 @@ def _maybe_expand_transformer_param_shape_or_error_(
        # Expand transformer parameter shapes if they don't match lora
        has_param_with_shape_update = False
        overwritten_params = {}

        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
        for name, module in transformer.named_modules():
            if isinstance(module, torch.nn.Linear):

@@ -2386,6 +2432,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
                        )

                    # For `unload_lora_weights()`.
                    # TODO: this could lead to more memory overhead if the number of overwritten params
                    # are large. Should be revisited later and tackled through a `discard_original_layers` arg.
                    overwritten_params[f"{current_module_name}.weight"] = module_weight
I think this would have a small but significant memory overhead. For inference purposes only with LoRAs, maybe this could be made opt-out if we know we never want to call unload_lora_weights().

Yeah, this could be tackled with the `discard_original_layers` arg mentioned in the TODO above.
                    if module_bias is not None:
                        overwritten_params[f"{current_module_name}.bias"] = module_bias

        if len(overwritten_params) > 0:
            transformer._overwritten_params = overwritten_params

        return has_param_with_shape_update

    @classmethod
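Regarding the memory-overhead discussion above: the cached originals live in transformer._overwritten_params (per this diff), so the overhead can be measured directly. A rough sketch; the helper name is made up:

def overwritten_params_nbytes(transformer) -> int:
    # Sum the storage of the original tensors cached for `unload_lora_weights()`.
    cached = getattr(transformer, "_overwritten_params", None) or {}
    return sum(t.numel() * t.element_size() for t in cached.values())

# After loading a Flux Control LoRA:
# print(f"{overwritten_params_nbytes(pipe.transformer) / 1024**2:.1f} MiB cached")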
Since we already pin the torch version, this is safe enough.
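The comment above appears to refer to the torch.device("meta") context manager and load_state_dict(..., assign=True), both of which need a reasonably recent torch. A standalone sketch of that restoration pattern (shapes are illustrative):

import torch

# Saved original weight (e.g. a pre-expansion x_embedder weight); shape illustrative.
saved_weight = torch.randn(3072, 64)

# Constructing the Linear under the meta device allocates no real parameter storage.
with torch.device("meta"):
    restored = torch.nn.Linear(
        saved_weight.shape[1], saved_weight.shape[0], bias=False, dtype=saved_weight.dtype
    )

# assign=True replaces the meta parameters with the saved tensors outright;
# an in-place copy into meta tensors would fail.
restored.load_state_dict({"weight": saved_weight}, assign=True, strict=True)
print(restored.weight.shape)  # torch.Size([3072, 64])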
Also cc: @a-r-r-o-w. Something we should consider doing in:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_pipeline.py#L2351-L2354
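The linked lines are where the expanded replacement Linear gets created during LoRA loading; the suggestion seems to be to build that module on the meta device as well and materialize it via assign, so only the padded weight tensor is ever allocated. A hedged sketch of that idea, not the library's actual code (sizes illustrative, 64 -> 128 as for Flux Control's x_embedder):

import torch

module = torch.nn.Linear(64, 3072, bias=True)  # stand-in for the original layer
expanded_in_features = 128                     # widened to accept the control latents

# Create the expanded layer on the meta device so its parameters take no memory yet.
with torch.device("meta"):
    expanded = torch.nn.Linear(
        expanded_in_features, module.out_features, bias=module.bias is not None, dtype=module.weight.dtype
    )

# Zero-pad the original weight into the expanded shape and assign it, so the
# only new allocation is the padded weight itself.
new_weight = torch.zeros(module.out_features, expanded_in_features, dtype=module.weight.dtype)
new_weight[:, : module.in_features] = module.weight.data
state = {"weight": new_weight}
if module.bias is not None:
    state["bias"] = module.bias.data
expanded.load_state_dict(state, assign=True, strict=True)
print(expanded.weight.shape)  # torch.Size([3072, 128])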