FEAT: Adding exclude modules param (#2044) #2102
Conversation
Thanks for taking on this feature. I haven't done an in-depth review yet, but from a quick look, I saw this:

- Right now, you only added the option to the `LoraConfig`. However, since the logic is implemented in the `check_target_module_exists` function, all methods can actually benefit from this feature (not the prompt learning ones, as they don't have `target_modules`, but all the rest like LoKr, BOFT, IA³, etc.). Let's add the argument to all relevant configs, i.e. the config of each method that has `target_modules`.
- We need to have tests for this new feature. Please check https://github.com/huggingface/peft/blob/main/tests/test_tuners_utils.py, which already contains a bunch of tests to check that the correct modules are targeted. Similar tests would need to be added for `exclude_modules`.

Could you please take a look? Don't hesitate to ask if you have questions.
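The two points above could be sketched like this: a shared matching helper with an `exclude_modules` check layered on top. This is a simplified stand-in, not PEFT's actual implementation; the `_matches` helper and its regex-vs-suffix handling are assumptions for illustration.

```python
import re
from types import SimpleNamespace

def _matches(patterns, key):
    # str -> treated as a regex (as target_modules supports); list -> suffix match
    if isinstance(patterns, str):
        return re.fullmatch(patterns, key) is not None
    return any(key == p or key.endswith("." + p) for p in patterns)

def check_target_module_exists(config, key):
    """Simplified stand-in: a key is targeted only if it matches
    target_modules and does not match exclude_modules."""
    exclude = getattr(config, "exclude_modules", None)
    if exclude and _matches(exclude, key):
        return False
    return _matches(config.target_modules, key)

cfg = SimpleNamespace(target_modules=["q_proj", "v_proj"], exclude_modules=["v_proj"])
print(check_target_module_exists(cfg, "model.layers.0.q_proj"))  # True
print(check_target_module_exists(cfg, "model.layers.0.v_proj"))  # False
```

Because the exclusion lives in the shared helper, every tuner whose config gains `exclude_modules` would pick up the behavior for free.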
@BenjaminBossan thanks for the heads-up, I'll work on the points.

@BenjaminBossan I've added the test cases too, let me know what you think.
Thanks a lot for the latest changes. The tests are already shaping up quite nicely, but let's improve the test coverage a bit more (see my suggestion).
Moreover, you still missed a few configs. These are all I could find, but please double-check:

- `AdaLoraConfig`
- `BoftConfig`
- `VBLoRAConfig`
- `FourierFTConfig`
- `LNTuningConfig`
- `PolyConfig`
```python
@@ -415,6 +415,35 @@ def test_realistic_example(self):
        ]
        assert model.targeted_module_names == expected


class TestExcludedModuleNames(unittest.TestCase):
```
Can we get additional tests:

- Let's add tests for `exclude_modules` being a list of str.
- What happens if all targeted modules are excluded? I think a nice error message would be the right call in this situation.
- What happens if we have excluded modules that don't exist? I think it's fine if some of the entries from `exclude_modules` don't match, as we also allow that for `target_modules`. But what if none of the entries in `exclude_modules` matches? I think this should also result in an error. WDYT?
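Those three cases could be sketched as plain assertions against a toy matcher. The `select_targets` helper below is hypothetical, standing in for the real config-and-model plumbing, but it shows the behaviors being asked for.

```python
def select_targets(keys, target_modules, exclude_modules=None):
    """Toy stand-in for adapter injection: return keys matching
    target_modules (by suffix) that are not excluded."""
    exclude_modules = exclude_modules or []
    matched = [k for k in keys if any(k.endswith(t) for t in target_modules)]
    targeted = [k for k in matched if not any(k.endswith(e) for e in exclude_modules)]
    if matched and not targeted:
        # everything was excluded: an explicit error beats a silent no-op
        raise ValueError("All target modules were excluded via exclude_modules.")
    return targeted

keys = ["m.lin0", "m.lin1", "m.lin2"]

# 1. exclude_modules as a list of str
assert select_targets(keys, ["lin0", "lin1"], ["lin1"]) == ["m.lin0"]

# 2. excluding every targeted module should raise a helpful error
try:
    select_targets(keys, ["lin0"], ["lin0"])
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError")

# 3. non-matching exclude entries are tolerated, like target_modules entries
assert select_targets(keys, ["lin0"], ["does_not_exist"]) == ["m.lin0"]
```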
Yeap, makes sense. Let me work through the tests and edge cases, and also the configs I missed.
Hey @BenjaminBossan, I managed to add cases 1 and 2. I wanted some help on how 3 should be done. As in, I see that currently `_check_target_module_exists` is a private abstract method that every PEFT method overrides to check for target module existence. Now, for `exclude_modules`, should we do the same? Or do I open a for loop within `inject_adapter` and check for keys there?
I don't think we need to go that far. If we look closer, we can see, e.g. for LoRA (`peft/src/peft/tuners/lora/model.py`, lines 158 to 160 in aa3bd8f):

```python
@staticmethod
def _check_target_module_exists(lora_config, key):
    return check_target_module_exists(lora_config, key)
```

And if we check where that's used, it all comes back to this line (`peft/src/peft/tuners/tuners_utils.py`, line 486 in aa3bd8f):

```python
if not self._check_target_module_exists(peft_config, key):
```

So what would be an elegant way to implement some sanity checks here? I have one proposal, but maybe you have a better idea. Right now, `check_target_module_exists` simply returns a bool, i.e. `True` or `False` depending on whether the module matched. We could extend that and, say, return a special object if a module was excluded due to `exclude_modules`.

This special value could be a very simple object. Its `__bool__` should return `False` (so that users who rely on checking `if not self._check_target_module_exists(peft_config, key)` don't suddenly get different results).

When we call `_check_target_module_exists`, we can collect keys that did not match because of `exclude_modules` in one list, and keys that did not match for other reasons in another list. When we exit the loop, if no key was matched because they all matched `exclude_modules`, we can raise a nice error message to say so. Otherwise, we raise the existing error (`peft/src/peft/tuners/tuners_utils.py`, lines 506 to 510 in aa3bd8f):

```python
if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
    raise ValueError(
        f"Target modules {peft_config.target_modules} not found in the base model. "
        f"Please check the target modules and try again."
    )
```

On the other hand, if not a single key was excluded due to `exclude_modules`, even though `exclude_modules` was passed by the user, we can give a nice warning saying something like "You passed exclude_modules=[...] but no key was matching this".
Thanks for your further work on this PR. You raised a good question about how to implement the checks I mentioned. I made a proposal, but please LMK if you have a better idea.
`src/peft/tuners/tuners_utils.py` (outdated):

```python
@@ -914,6 +914,20 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
        `bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or
        None if no match found
    """
    if config.target_modules and config.exclude_modules:
        if config.exclude_modules == config.target_modules:
```
I don't think that this is a sufficient check. Just to give some examples:

- `target_modules` is a list of str and `exclude_modules` is a str, but both essentially target the same modules
- `target_modules` is `["lin0", "lin1", "lin2"]` and `exclude_modules` is `["lin0", "lin1"]`. They look different, but if `"lin2"` doesn't exist, they're actually the same

We thus need to check which modules were actually targeted; we can't rely on the passed arguments.
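A toy suffix matcher makes the second example concrete (the `targeted` helper is hypothetical, not PEFT code):

```python
def targeted(keys, target_modules, exclude_modules):
    """Toy suffix matcher: which keys remain targeted after exclusion."""
    return [
        k for k in keys
        if any(k.endswith(t) for t in target_modules)
        and not any(k.endswith(e) for e in exclude_modules)
    ]

keys = ["model.lin0", "model.lin1"]  # note: no "lin2" in the model

# The arguments differ, yet the effective exclusion is total,
# so comparing exclude_modules == target_modules would miss this case:
assert targeted(keys, ["lin0", "lin1", "lin2"], ["lin0", "lin1"]) == []
```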
Hey, thanks for the suggestions. I think I fixed both issues. Happy to keep working if more fixes are needed, let me know.
Thanks for making these adjustments, we're getting closer to completing the PR. The general logic looks sound but I still have some comments about how the errors are raised. Please check my comments.
Also, please remember to call `make style` before pushing.
`src/peft/tuners/tuners_utils.py` (outdated):

```python
@@ -902,6 +912,12 @@ def generate_suffixes(s):
        return set(target_modules)
    return required_suffixes


class ExcludedModule:
```
Let's make this a "private" class: `_ExcludedModule`. Also, let's add a short docstring to describe why it's needed.
`src/peft/tuners/tuners_utils.py` (outdated):

```python
        return False

    def __repr__(self):
        return "ExcludedModule()"
```
Do you think we need this?
`src/peft/tuners/tuners_utils.py` (outdated):

```python
if hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
    raise ValueError(
```
I think raising an error here is a bit too much; let's switch to a warning instead. Also, I'd rephrase the message a bit: "You have passed `exclude_modules=...` but no modules were excluded, please check that `exclude_modules` was set correctly."
`src/peft/tuners/tuners_utils.py` (outdated):

```python
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

if len(self.targeted_module_names) == 0:
    # this means all the targeted modules have been excluded hence both are identical
```
I don't think this check is quite correct. There could be other reasons why we have not targeted any modules. Instead, I think we need to check that no module was targeted, that all modules were excluded and that there were no unmatched modules.
We should also collect all checks and put them in a single `if ... else` instead of having them in a few places, to increase readability. So the final code would be something like this:

```python
if not self.targeted_modules and excluded_modules and not unmatched_modules:
    # this means all targeted modules were excluded
    raise ...
elif not self.targeted_modules and not excluded_modules and unmatched_modules:
    # this means none of the targeted modules matched
    raise ...  # raise the old error message we already have
elif not self.targeted_modules:
    # this means some modules did not match and some matched but were excluded
    raise ...  # there is no error for this yet, let's raise a nice error here
elif hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
    # this is the case that exclude_modules was passed but was useless
    warnings.warn(...)
```
Also, let's ensure that we have a test case for each of these situations on top of the tests that are already there.
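The chain above can be made runnable with concrete messages. The function name, signature, and messages below are illustrative, not PEFT's actual code:

```python
import warnings

def finalize_injection(targeted, excluded_modules, unmatched_modules, exclude_modules_cfg):
    """Post-loop checks from the review sketch. `targeted` holds matched keys;
    `excluded_modules`/`unmatched_modules` are collected during the loop;
    `exclude_modules_cfg` is the value the user passed in the config."""
    if not targeted and excluded_modules and not unmatched_modules:
        # all targeted modules were excluded
        raise ValueError("All modules were excluded via exclude_modules.")
    elif not targeted and not excluded_modules and unmatched_modules:
        # none of the targeted modules matched
        raise ValueError("Target modules not found in the base model.")
    elif not targeted:
        # some modules did not match, the rest matched but were excluded
        raise ValueError("No modules were targeted; check target_modules and exclude_modules.")
    elif exclude_modules_cfg and not excluded_modules:
        # exclude_modules was passed but excluded nothing
        warnings.warn("You passed exclude_modules but no modules were excluded.")

# the happy path neither raises nor warns
finalize_injection(["m.lin0"], ["m.lin1"], [], ["lin1"])
```

Each branch corresponds to one of the test cases the reviewer asks for, so the tests can target them one by one.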
Gotcha, I've made the changes and I think it looks okay now.
Great refactor, thanks a lot, this looks very clean now. I only found a small issue, otherwise this looks good.
`src/peft/tuners/tuners_utils.py` (outdated):

```python
@@ -509,6 +544,11 @@ def inject_adapter(
        f"Please check the target modules and try again."
    )

    if hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
```
This check is redundant now with the one in line 524, right? Let's remove it.
Yeap, done now :)

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Please run

Hmm, weird, I actually ran `make style` already. Is this because of Windows? Should I switch to Ubuntu and try again?

I don't think it's a Windows issue. Do you perhaps have another ruff version installed? CI uses v0.6.8.

Yeah, I think it's the version, hopefully it works now.

Ah damn, now the doc-builder complains. Could you please run

My bad, I didn't notice the doc-builder wasn't installed. It should be alright now.

I see the tests failed due to a TypeError, this seems like an older Python version run, right?

Ah, CI is just relentless :D Okay, this time the issue is with older Python versions, we can't do type annotations like

Sure, lemme get on it.

Done!

You really seem to be on the bad side of the CI, now there is an issue with ruff again. It should be easy enough to address, please check.

Yeah, seems like after adding the import I had to lowercase the `List`: `Union[list[int], int]`

Of course that's not the end of it, now there is a merge conflict. It should be easy enough to fix, but don't hesitate to ask if you have questions. Thanks A TON for your patience.

Haha, no worries, hope this is okay now.

Hey, any clue why these tests are failing?

Looks like it's unrelated, I'll investigate tomorrow.

Alrighty.
Okay, I understand now what's going on. The issue is that with X-LoRA, we don't have any targeted modules, which is fine, but we currently raise an error because of it. See this check: `peft/src/peft/tuners/tuners_utils.py`, lines 505 to 510 in 2a80735.

I think the tests should pass if you move it up over
Ahh okay, like this?
The logic still doesn't work quite right, but the good news is that we can simplify it.

Right now, we check `self.targeted_module_names` and also `is_target_modules_in_base_model`, which is a bit redundant. So let's get rid of the latter.

Additionally, to ensure that X-LoRA does not raise an error (since there, it's okay to not have target modules), we need to check that we're not using dummy target modules (a special value used by X-LoRA). I sketched this out, and this diff makes the tests pass for me:

```diff
modified src/peft/tuners/tuners_utils.py
@@ -437,13 +437,12 @@ class BaseTuner(nn.Module, ABC):
         peft_config = self._prepare_adapter_config(peft_config, model_config)
         self._prepare_model(peft_config, model)

-        is_target_modules_in_base_model = False
         key_list = [key for key, _ in model.named_modules()]

-        if getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES:
+        uses_dummy_target_modules = getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES
+        if uses_dummy_target_modules:
             # dummy adapter, we allow not matching any module
             key_list = []
-            is_target_modules_in_base_model = True

         # update peft_config.target_modules if required
         peft_config = _maybe_include_all_linear_layers(peft_config, model)
@@ -494,19 +493,12 @@ class BaseTuner(nn.Module, ABC):
                 unmatched_modules.append(key)
             else:
                 self.targeted_module_names.append(key)
-                is_target_modules_in_base_model = True
                 parent, target, target_name = _get_submodules(model, key)
                 ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
                 with ctx():
                     self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

-        # Handle X-LoRA case.
-        if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
-            raise ValueError(
-                f"Target modules {peft_config.target_modules} not found in the base model. "
-                f"Please check the target modules and try again."
-            )
-        elif not self.targeted_module_names:
+        if not self.targeted_module_names and not uses_dummy_target_modules:
             if excluded_modules and not unmatched_modules:
                 # All targeted modules were excluded
                 raise ValueError(
```
Thanks so much for the detailed suggestion! Let's hope it's okay now.

Some unrelated tests are failing, right? I think with Windows some tensor equality is not fulfilled?
Yes, CI is failing for unrelated reasons. I'll just re-run until the flaky test passes :)
So this PR is basically finished, but when I did a final review, I saw a minor oversight, please check my comment. But then we should really be good to go!
```python
@@ -119,3 +128,6 @@ def __post_init__(self):
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
        self.exclude_modules = (
```
This conversion to `set` is missing in the `LoKrConfig`.
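The normalization being referred to can be sketched as a standalone dataclass (`DemoConfig` is hypothetical; the `__post_init__` body mirrors the pattern shown in the diff above):

```python
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class DemoConfig:
    """Hypothetical config showing the __post_init__ normalization
    that each config with target_modules/exclude_modules needs."""
    target_modules: Optional[Union[list, str]] = None
    exclude_modules: Optional[Union[list, str]] = None

    def __post_init__(self):
        # lists become sets for cheap membership checks; a str (regex) is kept as-is
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
        self.exclude_modules = (
            set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
        )

cfg = DemoConfig(target_modules=["q", "v"], exclude_modules=["v"])
print(cfg.exclude_modules)  # {'v'}
```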
Yeap, done! Thanks for guiding me all the way through!
Thanks for this nice PR, good work.
Allows excluding target modules.

This draft PR adds an exclude modules parameter to keep specific layers out of the target modules.

Addresses #2044

Changes made:

@BenjaminBossan Can you have a look at whether this is the right direction? This is my first time contributing to PEFT, would love to have your guidance here :)