diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 4a030dab94..cbc21c1453 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -344,7 +344,9 @@ def _check_forward_args(self, x, *args, **kwargs):
             msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
             raise ValueError(msg)
 
-        unique_adapters = set(self.active_adapters)
+        # DoRA is not supported (yet), check that it's not being used. Don't check "__base__", as this is the
+        # placeholder for the base model.
+        unique_adapters = {name for name in adapter_names if name != "__base__"}
         for adapter_name in unique_adapters:
             if self.use_dora.get(adapter_name, False):
                 msg = "Cannot pass `adapter_names` when DoRA is enabled."
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 3f3a97304f..0fa3bb823c 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -3485,7 +3485,7 @@ def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora):
             mlp_lora.forward(**inputs)
 
     def test_mixed_adapter_batches_lora_with_dora_raises(self):
-        # When there are Dora adapters, passing adapter names should raise an error
+        # When there are DoRA adapters, passing adapter names should raise an error
         torch.manual_seed(0)
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
@@ -3499,6 +3499,25 @@ def test_mixed_adapter_batches_lora_with_dora_raises(self):
         with pytest.raises(ValueError, match=msg):
             peft_model.forward(**inputs)
 
+    def test_mixed_adapter_batches_lora_with_dora_but_dora_not_included_works(self):
+        # When there are DoRA adapters, passing adapter names should raise an error, see previous test. However, when
+        # the adapter that uses DoRA is not included in adapter_names, it's actually fine.
+        torch.manual_seed(0)
+        base_model = MLP().to(self.torch_device).eval()
+        config_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=True)
+        peft_model = get_peft_model(base_model, config_dora)
+        config_no_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=False)
+        peft_model.add_adapter(adapter_name="other", peft_config=config_no_dora)
+        peft_model.eval()
+
+        # The "default" adapter uses DoRA but "other" is not using it, so using "other" is fine. Also, "__base__" is
+        # fine since it uses the base model and thus DoRA is not involved either.
+        inputs = {
+            "X": torch.arange(90).view(-1, 10).to(self.torch_device),
+            "adapter_names": ["other"] * 4 + ["__base__"] * 5,
+        }
+        peft_model.forward(**inputs)
+
     @require_non_cpu
     def test_mixed_adapter_batches_lora_opt_timing(self):
         # Use a more realistic model (opt-125m) and do a simple runtime check to ensure that mixed adapter batches
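
For context, here is a minimal standalone sketch of the behavior this patch enables: passing `adapter_names` is allowed as long as none of the requested adapters use DoRA, with "__base__" selecting the unadapted base model. The `MLP` class below is a hypothetical stand-in for the test helper of the same name (its layer sizes are made up), and the script assumes CPU execution; only the PEFT calls mirror the test above.

# Minimal sketch; the MLP below is a hypothetical stand-in, not the actual test helper.
import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, X):
        return self.lin1(torch.relu(self.lin0(X)))


torch.manual_seed(0)
base_model = MLP().eval()

# "default" uses DoRA, "other" is a plain LoRA adapter on the same layer.
config_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=True)
peft_model = get_peft_model(base_model, config_dora)
config_no_dora = LoraConfig(target_modules=["lin0"], init_lora_weights=False, use_dora=False)
peft_model.add_adapter(adapter_name="other", peft_config=config_no_dora)
peft_model.eval()

X = torch.arange(90, dtype=torch.float32).view(-1, 10)

# Fine after this patch: only "other" (no DoRA) and "__base__" (base model) are requested,
# even though the active "default" adapter has DoRA enabled.
peft_model.forward(X=X, adapter_names=["other"] * 4 + ["__base__"] * 5)

# Still raises ValueError, because the requested "default" adapter uses DoRA:
# peft_model.forward(X=X, adapter_names=["default"] * 9)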