From 1b1ebe7fc1d08bff3943c4f9ff20aaec7edfd7a1 Mon Sep 17 00:00:00 2001 From: Dawid <20214809+dkopi@users.noreply.github.com> Date: Fri, 31 May 2024 23:11:47 +0200 Subject: [PATCH 1/5] add support for adapting different layer shapes with VeRA --- docs/source/package_reference/vera.md | 1 - src/peft/tuners/vera/layer.py | 5 ++++- src/peft/tuners/vera/model.py | 29 +++++++++++---------------- tests/test_vera.py | 25 ++++++++++++++--------- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/docs/source/package_reference/vera.md b/docs/source/package_reference/vera.md index 9677df2742..0b4d2904f5 100644 --- a/docs/source/package_reference/vera.md +++ b/docs/source/package_reference/vera.md @@ -22,7 +22,6 @@ When saving the adapter parameters, it's possible to eschew storing the low rank VeRA currently has the following constraints: -- All targeted parameters must have the same shape. - Only `nn.Linear` layers are supported. - Quantized layers are not supported. diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py index e6c2e2ee1f..8978b52072 100644 --- a/src/peft/tuners/vera/layer.py +++ b/src/peft/tuners/vera/layer.py @@ -252,9 +252,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: vera_A = self.vera_A[active_adapter] vera_B = self.vera_B[active_adapter] + sliced_A = vera_A[:, : self.in_features] + sliced_B = vera_B[: self.out_features, :] + dropout = self.vera_dropout[active_adapter] x = x.to(lambda_d.dtype) - result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B) + result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B) result = result.to(previous_dtype) return result diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index 2ecd1c9ab8..a47112d94e 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -101,13 +101,11 @@ class VeraModel(BaseTuner): def __init__(self, model, config, adapter_name) -> None: super().__init__(model, config, adapter_name) - def _find_first_dim(self, config) -> tuple[int, int]: + def _find_dim(self, config) -> tuple[int, int]: """ - Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension. + Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA. This will be used for determining the size of the shared vera_A and vera_B matrices. - - This will throw an error if there are multiple layers of the same type with different shapes. """ model_config = getattr(self.model, "config", {"model_type": "custom"}) if hasattr(model_config, "to_dict"): @@ -116,7 +114,7 @@ def _find_first_dim(self, config) -> tuple[int, int]: peft_config = self._prepare_adapter_config(config, model_config) peft_config = _maybe_include_all_linear_layers(peft_config, self.model) - first_shape = None + largest_shape = None for key, module in self.model.named_modules(): if not self._check_target_module_exists(peft_config, key): continue @@ -128,24 +126,21 @@ def _find_first_dim(self, config) -> tuple[int, int]: else: continue - if first_shape is None: - first_shape = module_shape + if largest_shape is None: + largest_shape = module_shape continue - if module_shape != first_shape: - raise ValueError( - "Multiple target layers with different dimensions were specified. VeRA only supports a " - f"single dimension size. Expected shape {first_shape}, got {module_shape}." - ) + if module_shape != largest_shape: + largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape)) - if first_shape is None: + if largest_shape is None: msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`." raise ValueError(msg) - return first_shape + return largest_shape def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None: - first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config) + linear_out_dim, linear_in_dim = self._find_dim(config) # use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them. self.vera_A = BufferDict({}, persistent=config.save_projection) @@ -153,8 +148,8 @@ def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None: # deterministic init of vera_A and vera_B if we know the key generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) - vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator) - vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator) + vera_A = _kaiming_init((config.r, linear_in_dim), generator=generator) + vera_B = _kaiming_init((linear_out_dim, config.r), generator=generator) self.vera_A[adapter_name] = vera_A self.vera_B[adapter_name] = vera_B diff --git a/tests/test_vera.py b/tests/test_vera.py index 9fd3eca71c..10e3fa35b7 100644 --- a/tests/test_vera.py +++ b/tests/test_vera.py @@ -265,13 +265,18 @@ def test_vera_lambda_dont_share_memory(self, mlp_same_prng): != mlp_same_prng.base_model.model.lin2.vera_lambda_d["other"].data_ptr() ) - def test_vera_different_shapes_raises(self, mlp): - # It is not possible (currently) to have vera_A and vera_B for different shapes, as they cannot be shared if - # their shapes are not identical. lin0 and lin1 have different shapes. - config = VeraConfig(target_modules=["lin0", "lin1"], init_weights=False) - msg = re.escape( - "Multiple target layers with different dimensions were specified. VeRA only supports a single dimension " - "size. Expected shape (20, 10), got (20, 20)." - ) - with pytest.raises(ValueError, match=msg): - get_peft_model(mlp, config) + def test_vera_different_shapes(self, mlp): + config = VeraConfig(target_modules=["lin0", "lin3"], init_weights=False) + mlp_different_shapes = get_peft_model(mlp, config) + + vera_A = mlp_different_shapes.vera_A["default"] + vera_B = mlp_different_shapes.vera_B["default"] + + # lin0 has the largest output dimension, lin3 has the largest input dimension + # vera_A should have the shape of (rank, largest_in), vera_B should have the shape of (largest_out, rank) + assert vera_A.shape == (config.r, mlp.lin3.in_features) + assert vera_B.shape == (mlp.lin0.out_features, config.r) + + # should not raise + input = torch.randn(5, 10) + mlp_different_shapes(input) From 17f4d074d94983014cfc882095c838fdfa0234ee Mon Sep 17 00:00:00 2001 From: Dawid <20214809+dkopi@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:13:19 +0200 Subject: [PATCH 2/5] update example notebook for vera with different layer shapes --- examples/sequence_classification/VeRA.ipynb | 48 ++++++++------------- 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/examples/sequence_classification/VeRA.ipynb b/examples/sequence_classification/VeRA.ipynb index b917618db3..e3786fff45 100644 --- a/examples/sequence_classification/VeRA.ipynb +++ b/examples/sequence_classification/VeRA.ipynb @@ -94,7 +94,7 @@ " task_type=\"SEQ_CLS\", \n", " r=rank,\n", " d_initial=0.1,\n", - " target_modules=[\"query\", \"value\"],\n", + " target_modules=[\"query\", \"value\", \"intermediate.dense\"],\n", " save_projection=True,\n", ")\n", "head_lr = 1e-2\n", @@ -205,7 +205,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n" + "trainable params: 647,714 || all params: 125,294,884 || trainable%: 0.5170\n" ] } ], @@ -255,76 +255,76 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/29 [00:00 Date: Sat, 8 Jun 2024 17:42:00 +0200 Subject: [PATCH 3/5] adress requested changes --- docs/source/package_reference/vera.md | 2 ++ src/peft/tuners/vera/layer.py | 8 +++++++- tests/test_custom_models.py | 1 + tests/test_vera.py | 3 +++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/source/package_reference/vera.md b/docs/source/package_reference/vera.md index 0b4d2904f5..9f7bb19a38 100644 --- a/docs/source/package_reference/vera.md +++ b/docs/source/package_reference/vera.md @@ -20,6 +20,8 @@ rendered properly in your Markdown viewer. When saving the adapter parameters, it's possible to eschew storing the low rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default). +To handle different shapes of adapted layers, VeRA initializes shared A and B matrices with the largest required size for each dimension. During the forward pass, submatrices A and B for a given layer are sliced out from these shared matrices and used as described in the paper. For example, adapting two linear layers of shapes (100, 20) and (80, 50) will create A and B matrices of shapes (rank, 50) and (100, rank) respectively. Then, to adapt a layer of shape (100, 20), submatrices A and B of shapes (rank, 20) and (100, rank) will be extracted. + VeRA currently has the following constraints: - Only `nn.Linear` layers are supported. diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py index 8978b52072..1d18bd55bc 100644 --- a/src/peft/tuners/vera/layer.py +++ b/src/peft/tuners/vera/layer.py @@ -217,9 +217,12 @@ def get_delta_weight(self, adapter) -> torch.Tensor: lambda_d = lambda_d.float() lambda_b = lambda_b.float() + + sliced_A = vera_A[:, : self.in_features] + sliced_B = vera_B[: self.out_features, :] lambda_b = lambda_b.unsqueeze(-1) lambda_d = lambda_d.unsqueeze(-1) - output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out) + output_tensor = transpose((lambda_b * sliced_B) @ (lambda_d * sliced_A), self.fan_in_fan_out) if cast_to_fp32: output_tensor = output_tensor.to(dtype=dtype) @@ -252,6 +255,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: vera_A = self.vera_A[active_adapter] vera_B = self.vera_B[active_adapter] + # As adapted layers may have different shapes and VeRA contains a single shared pair of A and B matrices, + # we initialize these matrices with the largest required size for each dimension. + # During the forward pass, required submatrices are sliced out from the shared vera_A and vera_B. sliced_A = vera_A[:, : self.in_features] sliced_B = vera_B[: self.out_features, :] diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index c67c3e2d35..99f4d2f898 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -338,6 +338,7 @@ ("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}), ("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}), ("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}), + ("Vanilla MLP 4 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0", "lin1"]}), ( "Vanilla MLP 5 VeRA", "MLP", diff --git a/tests/test_vera.py b/tests/test_vera.py index 10e3fa35b7..3308bf7964 100644 --- a/tests/test_vera.py +++ b/tests/test_vera.py @@ -272,6 +272,9 @@ def test_vera_different_shapes(self, mlp): vera_A = mlp_different_shapes.vera_A["default"] vera_B = mlp_different_shapes.vera_B["default"] + # sanity check + assert mlp.lin0.base_layer.weight.shape != mlp.lin3.base_layer.weight.shape + # lin0 has the largest output dimension, lin3 has the largest input dimension # vera_A should have the shape of (rank, largest_in), vera_B should have the shape of (largest_out, rank) assert vera_A.shape == (config.r, mlp.lin3.in_features) From b2373de3b9678d0e9b06310f4ec5a5a6886ff73e Mon Sep 17 00:00:00 2001 From: Dawid <20214809+dkopi@users.noreply.github.com> Date: Mon, 10 Jun 2024 14:03:44 +0200 Subject: [PATCH 4/5] remove unused import --- tests/test_vera.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_vera.py b/tests/test_vera.py index 3308bf7964..6dbaac6bd3 100644 --- a/tests/test_vera.py +++ b/tests/test_vera.py @@ -15,7 +15,6 @@ # This test file is for tests specific to VeRA, since VeRA has some specific challenges due to the shared weights. import os -import re import pytest import torch From 9dd86b9ec59aeac98b510c2b5f61c90494715afe Mon Sep 17 00:00:00 2001 From: Dawid <20214809+dkopi@users.noreply.github.com> Date: Mon, 10 Jun 2024 14:11:06 +0200 Subject: [PATCH 5/5] remove additional line break --- src/peft/tuners/vera/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py index 1d18bd55bc..115e5305ef 100644 --- a/src/peft/tuners/vera/layer.py +++ b/src/peft/tuners/vera/layer.py @@ -217,7 +217,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor: lambda_d = lambda_d.float() lambda_b = lambda_b.float() - sliced_A = vera_A[:, : self.in_features] sliced_B = vera_B[: self.out_features, :] lambda_b = lambda_b.unsqueeze(-1)