
FEAT: Adding exclude modules param (#2044) #2102

Merged: 20 commits merged into huggingface:main on Oct 3, 2024

Conversation

@JINO-ROHIT (Contributor)

This draft PR adds the exclude_modules parameter to keep specific layers out of the target modules.

Addresses #2044

Changes made:

  1. Changed LoraConfig class to take in exclude_modules param.
  2. Added the exclusion logic within the check_target_module_exists method.

@BenjaminBossan Can you have a look and check whether this is the right direction? This is my first time contributing to peft, I would love to have your guidance here :)
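
For reference, a minimal sketch of how the new parameter could be used; the module names below are purely illustrative and depend on the base model, and the matching is assumed to follow the same conventions as target_modules:

from peft import LoraConfig

# Hypothetical example: apply LoRA to the attention projections, but skip the
# projections of the last decoder layer (module paths are just placeholders).
config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    exclude_modules=["layers.31.self_attn.q_proj", "layers.31.self_attn.v_proj"],
)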

@BenjaminBossan (Member) left a comment

Thanks for taking on this feature. I haven't done an in-depth review yet, but from a quick look, I saw this:

  1. Right now, you only added the option to the LoraConfig. However, since the logic is implemented in the check_target_module_exists function, all methods can actually benefit from this feature (not the prompt learning ones, as they don't have target_modules, but all the rest like LoKr, BOFT, IA³, etc.). Let's add the argument to all relevant configs, i.e. the config of each method that has target_modules.
  2. We need tests for this new feature. Please check https://github.com/huggingface/peft/blob/main/tests/test_tuners_utils.py, which already contains a bunch of tests to check that the correct modules are targeted. Similar tests would need to be added for exclude_modules. Could you please take a look? Don't hesitate to ask if you have questions.

@JINO-ROHIT (Contributor, Author)

@BenjaminBossan thanks for the heads-up, I'll work on the points

@JINO-ROHIT (Contributor, Author)

@BenjaminBossan I've added the test cases too, LMK what you think

@BenjaminBossan (Member) left a comment

Thanks a lot for the latest changes. The tests are already shaping up quite nicely, but let's improve the test coverage a bit more (see my suggestion).

Moreover, you missed a few configs still. These are all I could find, but please double-check:

  • AdaLoraConfig
  • BoftConfig
  • VBLoRAConfig
  • FourierFTConfig
  • LNTuningConfig
  • PolyConfig

@@ -415,6 +415,35 @@ def test_realistic_example(self):
]
assert model.targeted_module_names == expected

class TestExcludedModuleNames(unittest.TestCase):
Member

Can we get additional tests:

  1. Let's add tests for exclude_modules being a list of str.
  2. What happens if all targeted modules are excluded? I think a nice error message would be the right call in this situation.
  3. What happens if we have excluded modules that don't exist? I think it's fine if some of the entries from exclude_modules don't match, as we also allow that for target_modules. But what if none of the entries in exclude_modules matches? I think this should also result in an error. WDYT?
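
A rough sketch of what such tests could look like (the MLP model and test names are just placeholders, loosely following the style of the existing tests, and assuming the error behavior suggested in points 2 and 3 above):

import pytest
from torch import nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    # tiny toy model, just for illustration
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)


def test_exclude_modules_as_list_of_str():
    config = LoraConfig(target_modules=["lin0", "lin1"], exclude_modules=["lin1"])
    model = get_peft_model(MLP(), config)
    assert model.targeted_module_names == ["lin0"]


def test_all_target_modules_excluded_raises():
    config = LoraConfig(target_modules=["lin0", "lin1"], exclude_modules=["lin0", "lin1"])
    with pytest.raises(ValueError):
        get_peft_model(MLP(), config)


def test_exclude_modules_matches_nothing():
    config = LoraConfig(target_modules=["lin0"], exclude_modules=["does_not_exist"])
    with pytest.raises(ValueError):  # or pytest.warns, depending on what we decide for point 3
        get_peft_model(MLP(), config)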

Contributor Author

Yeap, makes sense, let me work through the tests and edge cases, and also the configs I missed.

Contributor Author

Hey @BenjaminBossan, I managed to add cases 1 and 2. I wanted some help on how 3 should be done. As in, I see that currently _check_target_module_exists is a private abstract method that every PEFT method overrides to check for target module existence. Now for exclude_modules, should we do the same? Or do I open a for loop within inject_adapter and check the keys there?

Member

I don't think we need to go that far. If we check closer, we can see e.g. for LoRA:

@staticmethod
def _check_target_module_exists(lora_config, key):
    return check_target_module_exists(lora_config, key)

And if we check where that's used, it all comes back to this line:

if not self._check_target_module_exists(peft_config, key):

So what would be an elegant way to implement some sanity checks here? I have one proposal, but maybe you have a better idea: Right now, check_target_module_exists simply returns a bool, i.e. True or False if the module matched. We could extend that and, say, return a special object if a module was excluded due to exclude_modules.

This special value could be a very simple object. Its __bool__ should return False (so that users who rely on checking if not self._check_target_module_exists(peft_config, key) don't suddenly get different results).

When we call _check_target_module_exists, we can collect keys that did not match because of exclude_modules in one list and keys that did not match for other reasons in another list. When we exit the loop, if no key was matched because they were all matching exclude_modules, we can raise a nice error message to say so. Otherwise, we raise the existing error:

if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
    raise ValueError(
        f"Target modules {peft_config.target_modules} not found in the base model. "
        f"Please check the target modules and try again."
    )

On the other hand, if not a single key was excluded due to exclude_modules, even though exclude_modules was passed by the user, we can give a nice warning saying something like "You passed exclude_modules=[...] but no key was matching this".
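
To make the idea a bit more concrete, here is a rough sketch (just an illustration, not a final implementation; the class name is a placeholder):

class ExcludedModule:
    # Sentinel that check_target_module_exists could return when a key is
    # skipped because of exclude_modules. It is falsy, so existing
    # "if not self._check_target_module_exists(...)" checks behave as before,
    # but the caller can still tell *why* the key did not match.
    def __bool__(self):
        return False


result = ExcludedModule()
assert not result
assert isinstance(result, ExcludedModule)

# In inject_adapter (sketch): collect keys skipped due to exclude_modules in
# one list and keys that simply did not match in another, then decide at the
# end of the loop whether to raise the new error, the existing error, or a
# warning.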

@BenjaminBossan (Member) left a comment

Thanks for your further work on this PR. You raised a good question about how to implement the checks I mentioned. I made a proposal, but please LMK if you have a better idea.

@@ -914,6 +914,20 @@ def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
`bool` | `re.Match[str]` | `None`: True of match object if key matches any target modules from config, False or
None if no match found
"""
if config.target_modules and config.exclude_modules:
    if config.exclude_modules == config.target_modules:
Member

I don't think that this is a sufficient check. Just to give some examples:

  • target_modules is a list of str and exclude_modules is a str, but both essentially target the same modules
  • target_modules is ["lin0", "lin1", "lin2"] and exclude_modules is ["lin0", "lin1"]. They look different, but if "lin2" doesn't exist, they're actually the same

We thus need to check what modules were actually targeted, we can't rely on the passed arguments.
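
For example (illustrative only, and assuming exclude_modules follows the same str-as-regex / list-of-suffixes conventions as target_modules):

from peft import LoraConfig

# looks different, but if the model only has lin0 and lin1, everything that
# would be targeted is also excluded
LoraConfig(target_modules=["lin0", "lin1", "lin2"], exclude_modules=["lin0", "lin1"])

# list of str vs. str (regex), yet both effectively cover the same modules
LoraConfig(target_modules=["lin0", "lin1"], exclude_modules="lin[01]")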


@JINO-ROHIT (Contributor, Author)

Hey, thanks for the suggestions. I think I fixed both issues. Happy to keep working on it if more fixes are needed, LMK.

@BenjaminBossan (Member) left a comment

Thanks for making these adjustments, we're getting closer to completing the PR. The general logic looks sound but I still have some comments about how the errors are raised. Please check my comments.

Also, please remember to call make style before pushing.

@@ -902,6 +912,12 @@ def generate_suffixes(s):
return set(target_modules)
return required_suffixes

class ExcludedModule:
Member

Let's make this a "private" class: _ExcludedModule. Also, let's add a short docstring to describe why it's needed.
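
Something along these lines (just a sketch of what I mean):

class _ExcludedModule:
    """
    A private helper to distinguish between a key not matching any target module
    and the key being excluded via exclude_modules. It evaluates as falsy so
    that existing boolean checks on the result keep working.
    """

    def __bool__(self):
        return False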

        return False

    def __repr__(self):
        return "ExcludedModule()"
Member

Do you think we need this?



if hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
    raise ValueError(
Member

I think raising an error here is a bit too much, let's switch to a warning instead. Also, I'd rephrase the message a bit: "You have passed exclude_modules=... but no modules were excluded, please check that exclude_modules was set correctly.".
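
I.e. something along these lines (the wording is just a suggestion; peft_config is the config already in scope in inject_adapter):

import warnings

warnings.warn(
    f"You have passed exclude_modules={peft_config.exclude_modules} but no modules were excluded. "
    "Please check that exclude_modules was set correctly."
)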

self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)

if len(self.targeted_module_names) == 0:
    # this means all the targeted modules have been excluded hence both are identical
Member

I don't think this check is quite correct. There could be other reasons why we have not targeted any modules. Instead, I think we need to check that no module was targeted, that all modules were excluded and that there were no unmatched modules.

We should also collect all checks and put them in a single if ... else instead of having them in a few places, to increase readability. So the final code would be something like this:

if not self.targeted_modules and excluded_modules and not unmatched_modules:
    # this means all targeted modules were excluded
    raise ...
elif not self.targeted_modules and not excluded_modules and unmatched_modules:
    # this means none of the targeted modules matched
    raise ...  # raise the old error message we already have
elif not self.targeted_modules:
    # this means some modules did not match and some matched but were excluded
    raise ...  # there is no error for this yet, let's raise a nice error here
elif hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
    # this is the case that exclude_modules was passed but was useless
    warnings.warn ...

Also, let's ensure that we have a test case for each of these situations on top of the tests that are already there.

@JINO-ROHIT (Contributor, Author)

Gotcha, I've made the changes and I think it looks okayish now.

@BenjaminBossan (Member) left a comment

Great refactor, thanks a lot, this looks very clean now. I only found a small issue, otherwise this looks good.

@@ -509,6 +544,11 @@ def inject_adapter(
f"Please check the target modules and try again."
)

if hasattr(peft_config, "exclude_modules") and peft_config.exclude_modules and not excluded_modules:
Member

This check is redundant now with the one in line 524, right? Let's remove it.

@JINO-ROHIT (Contributor, Author)

yeap, done now :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member)

Please run make style :)

@JINO-ROHIT (Contributor, Author)

Hmm, weird, I actually ran make style already. Is this because of Windows? Should I switch to Ubuntu and try again?

@BenjaminBossan (Member)

I don't think it's a Windows issue. Do you perhaps have another ruff version installed? CI uses v0.6.8.

@JINO-ROHIT (Contributor, Author)

Yeah, I think it's the version, hopefully it works now.

@BenjaminBossan (Member)

Ah damn, now the doc-builder complains. Could you please run doc-builder style src/peft tests docs/source --max_len 119 --check_only. Normally, this should run as part of make style but I assume you ran ruff separately.

@JINO-ROHIT (Contributor, Author)

My bad, I didn't notice that doc-builder wasn't installed, it should be alright now.

@JINO-ROHIT (Contributor, Author)

I see the tests failed due to a TypeError, this seems like an older Python version run, right?

@BenjaminBossan (Member)

Ah, CI is just relentless :D Okay, this time the issue is with older Python versions: we can't use type annotations like list[str] unless we add the from __future__ import annotations import. At the very least, this is missing in the LoHa config file, but probably also in a few more. Could you please fix that? Sorry that I missed it.
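
For illustration, a minimal standalone example of the pattern (ExampleConfig is just a stand-in, not the real LoHa config):

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class ExampleConfig:
    # With the __future__ import at the top of the module, built-in generics
    # such as list[str] are accepted in annotations even on older Python
    # versions, because annotations are no longer evaluated at definition time.
    exclude_modules: Optional[Union[list[str], str]] = field(default=None)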

@JINO-ROHIT (Contributor, Author)

Sure, lemme get on it.

@JINO-ROHIT (Contributor, Author)

done!

@BenjaminBossan (Member)

You really seem to be on the bad side of the CI, now there is an issue with ruff again. It should be easy enough to address, please check.

@JINO-ROHIT (Contributor, Author)

Yeah, seems like after adding the import, I had to lowercase List, i.e. Union[list[int], int].

@BenjaminBossan (Member)

Of course that's not the end of it, now there is a merge conflict. It should be easy enough to fix but don't hesitate to ask if you have questions.

Thanks A TON for your patience.

@JINO-ROHIT (Contributor, Author)

haha no worries, hope this is okay now.

@JINO-ROHIT (Contributor, Author)

hey, any clue why these tests are failing?

@BenjaminBossan (Member)

Looks like it's unrelated, I'll investigate tomorrow.

@JINO-ROHIT (Contributor, Author)

alrighty

@BenjaminBossan (Member)

Okay, I understand now what's going on. The issue is that with X-LoRA, we don't have any targeted modules, which is fine, but we currently raise an error because of it. See this check:

# Handle X-LoRA case.
if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
    raise ValueError(
        f"Target modules {peft_config.target_modules} not found in the base model. "
        f"Please check the target modules and try again."
    )

I think the tests should pass if you move it up above the if not self.targeted_module_names check in line 503, then change that line to elif not self.targeted_module_names.

@JINO-ROHIT (Contributor, Author)

ahh okay, like this?

@BenjaminBossan (Member) left a comment

The logic still doesn't work quite right but the good news is that we can simplify it.

Right now, we check self.targeted_module_names and also is_target_modules_in_base_model, which is a bit redundant. So let's get rid of the latter.

Additionally, to ensure that X-LoRA does not raise an error (since there, it's okay to not have target modules), we need to check that we're not using dummy target modules (a special value used by X-LoRA). I sketched this out and this diff makes the tests pass for me:

modified   src/peft/tuners/tuners_utils.py
@@ -437,13 +437,12 @@ class BaseTuner(nn.Module, ABC):
         peft_config = self._prepare_adapter_config(peft_config, model_config)
 
         self._prepare_model(peft_config, model)
-        is_target_modules_in_base_model = False
         key_list = [key for key, _ in model.named_modules()]
 
-        if getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES:
+        uses_dummy_target_modules = getattr(peft_config, "target_modules", None) == DUMMY_TARGET_MODULES
+        if uses_dummy_target_modules:
             # dummy adapter, we allow not matching any module
             key_list = []
-            is_target_modules_in_base_model = True
 
         # update peft_config.target_modules if required
         peft_config = _maybe_include_all_linear_layers(peft_config, model)
@@ -494,19 +493,12 @@ class BaseTuner(nn.Module, ABC):
                 unmatched_modules.append(key)
             else:
                 self.targeted_module_names.append(key)
-                is_target_modules_in_base_model = True
                 parent, target, target_name = _get_submodules(model, key)
                 ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
                 with ctx():
                     self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
 
-        # Handle X-LoRA case.
-        if not is_target_modules_in_base_model and hasattr(peft_config, "target_modules"):
-            raise ValueError(
-                f"Target modules {peft_config.target_modules} not found in the base model. "
-                f"Please check the target modules and try again."
-            )
-        elif not self.targeted_module_names:
+        if not self.targeted_module_names and not uses_dummy_target_modules:
             if excluded_modules and not unmatched_modules:
                 # All targeted modules were excluded
                 raise ValueError(

@JINO-ROHIT (Contributor, Author) commented Oct 2, 2024

Thanks so much for the detailed suggestion!

Let's hope it's okay now.

@JINO-ROHIT (Contributor, Author)

Some unrelated tests failing, right? I think with Windows some tensor equality is not fulfilled?

@BenjaminBossan (Member) left a comment

Yes, CI is failing for unrelated reasons. I'll just re-run until the flaky test passes :)

So this PR is basically finished, but when I did a final review, I saw a minor oversight, please check my comment. But then we should really be good to go!

@@ -119,3 +128,6 @@ def __post_init__(self):
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
        self.exclude_modules = (
Member

This conversion to set is missing in the LoKrConfig.
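
I.e. something like this in LoKrConfig.__post_init__, mirroring the existing target_modules conversion (sketch with a stand-in class, not the actual config):

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ConfigSketch:
    target_modules: Optional[Union[list[str], str]] = None
    exclude_modules: Optional[Union[list[str], str]] = None

    def __post_init__(self):
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
        # this conversion was missing for exclude_modules in LoKrConfig:
        self.exclude_modules = (
            set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
        )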

@JINO-ROHIT (Contributor, Author)

Yeap, done! Thanks for guiding me all the way through!

@BenjaminBossan (Member) left a comment

Thanks for this nice PR, good work.

@BenjaminBossan BenjaminBossan merged commit 8d9ecbe into huggingface:main Oct 3, 2024
14 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Oct 22, 2024