Skip to content

Commit

Permalink
FIX: Bug in find_minimal_target_modules (#2083)
Browse files Browse the repository at this point in the history
This bug was reported by Sayak and would occur if a required suffix had
itself as suffix a string that was already determined to be required, in
which case this required suffix would not be added.

The fix consists of prefixing a "." to the suffix before checking if it is
required or not.

On top of this, the algorithm has been changed to be deterministic.
Previously, it was not deterministic because a dictionary that was
looped over was built from a set, and sets don't guarantee order. This
would result in the loop being in arbitrary order.

As long as the algorithm is 100% correct, the order should not matter.
But in case we find bugs like this, the order does matter. We don't want
bugs to be flaky, therefore it is best to sort the dict and remove
randomness from the function.

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
BenjaminBossan and sayakpaul authored Sep 23, 2024
1 parent 5efeba1 commit b67c9b6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,14 +885,16 @@ def generate_suffixes(s):
# Initialize a set for required suffixes
required_suffixes = set()

for item, suffixes in target_modules_suffix_map.items():
# We sort the target_modules_suffix_map simply to get deterministic behavior, since sets have no order. In theory
# the order should not matter but in case there is a bug, it's better for the bug to be deterministic.
for item, suffixes in sorted(target_modules_suffix_map.items(), key=lambda tup: tup[1]):
# Go through target_modules items, shortest suffixes first
for suffix in suffixes:
# If the suffix is already in required_suffixes or matches other_module_names, skip it
if suffix in required_suffixes or suffix in other_module_suffixes:
continue
# Check if adding this suffix covers the item
if not any(item.endswith(req_suffix) for req_suffix in required_suffixes):
if not any(item.endswith("." + req_suffix) for req_suffix in required_suffixes):
required_suffixes.add(suffix)
break

Expand Down
45 changes: 45 additions & 0 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,3 +1282,48 @@ def test_get_peft_model_applies_find_target_modules(self):
# check that the resulting model is still the same
model_check_after = sum(p.sum() for p in model.parameters())
assert model_check_sum_before == model_check_after

def test_suffix_is_substring_of_other_suffix(self):
# This test is based on a real world bug found in diffusers. The issue was that we needed the suffix
# 'time_emb_proj' in the minimal target modules. However, if there already was the suffix 'proj' in the
# required_suffixes, 'time_emb_proj' would not be added because the test was `endswith(suffix)` and
# 'time_emb_proj' ends with 'proj'. The correct logic is to test if `endswith("." + suffix")`. The module names
# chosen here are only a subset of the hundreds of actual module names but this subset is sufficient to
# replicate the bug.
target_modules = [
"down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj",
"mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj",
"up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj",
"mid_block.attentions.0.proj_out",
"up_blocks.0.attentions.0.proj_out",
"down_blocks.1.attentions.0.proj_out",
"up_blocks.0.resnets.0.time_emb_proj",
"down_blocks.0.resnets.0.time_emb_proj",
"mid_block.resnets.0.time_emb_proj",
]
other_module_names = [
"conv_in",
"time_proj",
"time_embedding",
"time_embedding.linear_1",
"add_time_proj",
"add_embedding",
"add_embedding.linear_1",
"add_embedding.linear_2",
"down_blocks",
"down_blocks.0",
"down_blocks.0.resnets",
"down_blocks.0.resnets.0",
"up_blocks",
"up_blocks.0",
"up_blocks.0.attentions",
"up_blocks.0.attentions.0",
"up_blocks.0.attentions.0.norm",
"up_blocks.0.attentions.0.transformer_blocks",
"up_blocks.0.attentions.0.transformer_blocks.0",
"up_blocks.0.attentions.0.transformer_blocks.0.norm1",
"up_blocks.0.attentions.0.transformer_blocks.0.attn1",
]
expected = {"time_emb_proj", "proj", "proj_out"}
result = find_minimal_target_modules(target_modules, other_module_names)
assert result == expected

0 comments on commit b67c9b6

Please sign in to comment.