Support for different layer shapes for VeRA #1817

Merged · 5 commits · Jun 10, 2024
Changes from 2 commits
1 change: 0 additions & 1 deletion docs/source/package_reference/vera.md
@@ -22,7 +22,6 @@ When saving the adapter parameters, it's possible to eschew storing the low rank

VeRA currently has the following constraints:

- - All targeted parameters must have the same shape.
- Only `nn.Linear` layers are supported.
- Quantized layers are not supported.

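With this constraint dropped, a single config can now mix differently shaped target modules. A minimal sketch of what that enables (the model choice, rank, and module names here are illustrative assumptions, not taken from the PR):

from transformers import AutoModelForSequenceClassification
from peft import VeraConfig, get_peft_model

model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
config = VeraConfig(
    task_type="SEQ_CLS",
    r=8,
    # "query"/"value" are 768->768 projections, while "intermediate.dense" is
    # 768->3072; mixing shapes like this previously raised a ValueError.
    target_modules=["query", "value", "intermediate.dense"],
    save_projection=True,
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()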
48 changes: 18 additions & 30 deletions examples/sequence_classification/VeRA.ipynb
@@ -94,7 +94,7 @@
" task_type=\"SEQ_CLS\", \n",
" r=rank,\n",
" d_initial=0.1,\n",
" target_modules=[\"query\", \"value\"],\n",
" target_modules=[\"query\", \"value\", \"intermediate.dense\"],\n",
" save_projection=True,\n",
")\n",
"head_lr = 1e-2\n",
@@ -205,7 +205,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n"
"trainable params: 647,714 || all params: 125,294,884 || trainable%: 0.5170\n"
]
}
],
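As an aside (not part of the diff), the jump in trainable parameters is consistent with VeRA's per-layer trainables, assuming the notebook's rank variable is 8: each newly targeted intermediate.dense layer (out_features=3072) adds a trainable lambda_d vector of size r and a lambda_b vector of size out_features, while the shared vera_A/vera_B projections are buffers and are not counted as parameters:

r, out_features, num_layers = 8, 3072, 12  # assumed rank; roberta-base has 12 layers
print(num_layers * (r + out_features))     # 36960
print(647_714 - 610_754)                   # 36960, the observed increase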
@@ -255,76 +255,76 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.33it/s]\n"
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|██████████| 29/29 [00:18<00:00, 1.58it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.52it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.823529411764706}\n"
"epoch 0: {'accuracy': 0.7475490196078431, 'f1': 0.8367670364500792}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.26it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.68it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8484848484848485}\n"
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8536209553158706}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.66it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.33it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2: {'accuracy': 0.8259803921568627, 'f1': 0.8738898756660745}\n"
"epoch 2: {'accuracy': 0.8553921568627451, 'f1': 0.8959435626102292}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.41it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.64it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3: {'accuracy': 0.8431372549019608, 'f1': 0.891156462585034}\n"
"epoch 3: {'accuracy': 0.8823529411764706, 'f1': 0.9133574007220215}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.35it/s]"
"100%|██████████| 29/29 [00:17<00:00, 1.63it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
"epoch 4: {'accuracy': 0.8897058823529411, 'f1': 0.9183303085299456}\n"
]
},
{
@@ -520,18 +520,6 @@
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
5 changes: 4 additions & 1 deletion src/peft/tuners/vera/layer.py
@@ -252,9 +252,12 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                vera_A = self.vera_A[active_adapter]
                vera_B = self.vera_B[active_adapter]

+               sliced_A = vera_A[:, : self.in_features]
+               sliced_B = vera_B[: self.out_features, :]
+
                dropout = self.vera_dropout[active_adapter]
                x = x.to(lambda_d.dtype)
-               result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B)
+               result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B)

        result = result.to(previous_dtype)
        return result
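The idea behind the slicing, as a self-contained sketch (shapes and toy values are assumed, not PEFT internals): the shared projections are allocated once at the largest required size, and each adapted layer slices out only the leading rows/columns it needs.

import torch
import torch.nn.functional as F

r, max_in, max_out = 8, 3072, 3072
vera_A = torch.randn(r, max_in)   # shared, frozen random projection
vera_B = torch.randn(max_out, r)  # shared, frozen random projection

def vera_delta(x, in_features, out_features, lambda_d, lambda_b):
    # slice the shared matrices down to this layer's dimensions
    sliced_A = vera_A[:, :in_features]   # (r, in_features)
    sliced_B = vera_B[:out_features, :]  # (out_features, r)
    return lambda_b * F.linear(lambda_d * F.linear(x, sliced_A), sliced_B)

x = torch.randn(2, 768)
# e.g. a 768->768 attention projection and a 768->3072 intermediate layer
print(vera_delta(x, 768, 768, torch.ones(8), torch.ones(768)).shape)    # (2, 768)
print(vera_delta(x, 768, 3072, torch.ones(8), torch.ones(3072)).shape)  # (2, 3072)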
29 changes: 12 additions & 17 deletions src/peft/tuners/vera/model.py
@@ -101,13 +101,11 @@ class VeraModel(BaseTuner):
    def __init__(self, model, config, adapter_name) -> None:
        super().__init__(model, config, adapter_name)

-   def _find_first_dim(self, config) -> tuple[int, int]:
+   def _find_dim(self, config) -> tuple[int, int]:
        """
-       Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension.
+       Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.

        This will be used for determining the size of the shared vera_A and vera_B matrices.
-
-       This will throw an error if there are multiple layers of the same type with different shapes.
        """
        model_config = getattr(self.model, "config", {"model_type": "custom"})
        if hasattr(model_config, "to_dict"):
@@ -116,7 +114,7 @@ def _find_first_dim(self, config) -> tuple[int, int]:
        peft_config = self._prepare_adapter_config(config, model_config)
        peft_config = _maybe_include_all_linear_layers(peft_config, self.model)

-       first_shape = None
+       largest_shape = None
        for key, module in self.model.named_modules():
            if not self._check_target_module_exists(peft_config, key):
                continue
@@ -128,33 +126,30 @@
            else:
                continue

-           if first_shape is None:
-               first_shape = module_shape
+           if largest_shape is None:
+               largest_shape = module_shape
                continue

-           if module_shape != first_shape:
-               raise ValueError(
-                   "Multiple target layers with different dimensions were specified. VeRA only supports a "
-                   f"single dimension size. Expected shape {first_shape}, got {module_shape}."
-               )
+           if module_shape != largest_shape:
+               largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))

-       if first_shape is None:
+       if largest_shape is None:
            msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`."
            raise ValueError(msg)

-       return first_shape
+       return largest_shape

    def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None:
-       first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config)
+       linear_out_dim, linear_in_dim = self._find_dim(config)

        # use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them.
        self.vera_A = BufferDict({}, persistent=config.save_projection)
        self.vera_B = BufferDict({}, persistent=config.save_projection)

        # deterministic init of vera_A and vera_B if we know the key
        generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
-       vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator)
-       vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator)
+       vera_A = _kaiming_init((config.r, linear_in_dim), generator=generator)
+       vera_B = _kaiming_init((linear_out_dim, config.r), generator=generator)
        self.vera_A[adapter_name] = vera_A
        self.vera_B[adapter_name] = vera_B
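A toy illustration of the aggregation _find_dim now performs (the shape list below is made up): instead of rejecting mismatched shapes, it keeps an elementwise maximum over all targeted (out_features, in_features) pairs.

shapes = [(768, 768), (3072, 768), (768, 3072)]  # hypothetical (out, in) pairs

largest_shape = None
for module_shape in shapes:
    if largest_shape is None:
        largest_shape = module_shape
    elif module_shape != largest_shape:
        largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))

print(largest_shape)  # (3072, 3072): vera_B needs 3072 rows, vera_A needs 3072 columns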
25 changes: 15 additions & 10 deletions tests/test_vera.py
@@ -265,13 +265,18 @@ def test_vera_lambda_dont_share_memory(self, mlp_same_prng):
            != mlp_same_prng.base_model.model.lin2.vera_lambda_d["other"].data_ptr()
        )

-   def test_vera_different_shapes_raises(self, mlp):
-       # It is not possible (currently) to have vera_A and vera_B for different shapes, as they cannot be shared if
-       # their shapes are not identical. lin0 and lin1 have different shapes.
-       config = VeraConfig(target_modules=["lin0", "lin1"], init_weights=False)
-       msg = re.escape(
-           "Multiple target layers with different dimensions were specified. VeRA only supports a single dimension "
-           "size. Expected shape (20, 10), got (20, 20)."
-       )
-       with pytest.raises(ValueError, match=msg):
-           get_peft_model(mlp, config)
+   def test_vera_different_shapes(self, mlp):
+       config = VeraConfig(target_modules=["lin0", "lin3"], init_weights=False)
+       mlp_different_shapes = get_peft_model(mlp, config)
+
+       vera_A = mlp_different_shapes.vera_A["default"]
+       vera_B = mlp_different_shapes.vera_B["default"]
+
+       # lin0 has the largest output dimension, lin3 has the largest input dimension
+       # vera_A should have the shape of (rank, largest_in), vera_B should have the shape of (largest_out, rank)
+       assert vera_A.shape == (config.r, mlp.lin3.in_features)
+       assert vera_B.shape == (mlp.lin0.out_features, config.r)

+       # should not raise
+       input = torch.randn(5, 10)
+       mlp_different_shapes(input)
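For context, here is an MLP fixture consistent with the assertions above (the actual fixture is defined elsewhere in tests/test_vera.py; these layer shapes are inferred from the old error message and the test comments, so treat them as an assumption):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)  # largest out_features among the targets
        self.lin1 = nn.Linear(20, 20)
        self.lin2 = nn.Linear(20, 20)
        self.lin3 = nn.Linear(20, 2)   # largest in_features among the targets
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.lin0(x))
        x = self.relu(self.lin1(x))
        x = self.relu(self.lin2(x))
        return self.lin3(x)

With target_modules=["lin0", "lin3"], the shared buffers come out as vera_A: (r, 20) and vera_B: (20, r), and lin3's forward slices all 20 input columns but only its 2 output rows.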