Support for different layer shapes for VeRA #1817

Merged: 5 commits, Jun 10, 2024
docs/source/package_reference/vera.md (3 changes: 2 additions, 1 deletion)
@@ -20,9 +20,10 @@ rendered properly in your Markdown viewer.

When saving the adapter parameters, it's possible to eschew storing the low-rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).
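For illustration, a minimal sketch of this option (the field names are those of `VeraConfig` as used elsewhere in this PR; the `r` and seed values are arbitrary):

```python
from peft import VeraConfig

# Sketch: omit the shared projection matrices from the checkpoint; on load they
# are regenerated deterministically from `projection_prng_key`.
config = VeraConfig(
    r=256,
    target_modules=["query", "value"],
    projection_prng_key=0,
    save_projection=False,  # smaller checkpoint; cross-device reproducibility not guaranteed
)
```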

+To handle different shapes of adapted layers, VeRA initializes shared A and B matrices with the largest required size for each dimension. During the forward pass, submatrices A and B for a given layer are sliced out from these shared matrices and used as described in the paper. For example, adapting two linear layers of shapes (100, 20) and (80, 50) will create A and B matrices of shapes (rank, 50) and (100, rank) respectively. Then, to adapt a layer of shape (100, 20), submatrices A and B of shapes (rank, 20) and (100, rank) will be extracted.
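A minimal sketch of that slicing in plain PyTorch, using the shapes from the example above:

```python
import torch

rank = 4
# Layers of shape (100, 20) and (80, 50) are adapted, so the shared matrices take
# the largest size per dimension: max in_features = 50, max out_features = 100.
vera_A = torch.randn(rank, 50)   # (rank, max in_features)
vera_B = torch.randn(100, rank)  # (max out_features, rank)

# To adapt the (100, 20) layer, slice out submatrices matching its dimensions.
sliced_A = vera_A[:, :20]    # (rank, 20)
sliced_B = vera_B[:100, :]   # (100, rank)
```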

VeRA currently has the following constraints:

-- All targeted parameters must have the same shape.
- Only `nn.Linear` layers are supported.
- Quantized layers are not supported.
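With the same-shape constraint lifted by this PR, a single config can now target differently shaped `nn.Linear` layers. A minimal sketch mirroring the notebook change below (module names are from the RoBERTa example):

```python
from peft import VeraConfig

# Sketch: "intermediate.dense" has a different shape than "query"/"value",
# which previously raised an error and is supported after this change.
config = VeraConfig(
    r=8,
    d_initial=0.1,
    target_modules=["query", "value", "intermediate.dense"],
    save_projection=True,
)
```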

examples/sequence_classification/VeRA.ipynb (48 changes: 18 additions, 30 deletions)
@@ -94,7 +94,7 @@
" task_type=\"SEQ_CLS\", \n",
" r=rank,\n",
" d_initial=0.1,\n",
" target_modules=[\"query\", \"value\"],\n",
" target_modules=[\"query\", \"value\", \"intermediate.dense\"],\n",
" save_projection=True,\n",
")\n",
"head_lr = 1e-2\n",
@@ -205,7 +205,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n"
"trainable params: 647,714 || all params: 125,294,884 || trainable%: 0.5170\n"
]
}
],
@@ -255,76 +255,76 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.33it/s]\n"
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|██████████| 29/29 [00:18<00:00, 1.58it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.52it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.823529411764706}\n"
"epoch 0: {'accuracy': 0.7475490196078431, 'f1': 0.8367670364500792}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.26it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.68it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8484848484848485}\n"
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8536209553158706}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.66it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.33it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2: {'accuracy': 0.8259803921568627, 'f1': 0.8738898756660745}\n"
"epoch 2: {'accuracy': 0.8553921568627451, 'f1': 0.8959435626102292}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.41it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.64it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3: {'accuracy': 0.8431372549019608, 'f1': 0.891156462585034}\n"
"epoch 3: {'accuracy': 0.8823529411764706, 'f1': 0.9133574007220215}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.35it/s]"
"100%|██████████| 29/29 [00:17<00:00, 1.63it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
"epoch 4: {'accuracy': 0.8897058823529411, 'f1': 0.9183303085299456}\n"
]
},
{
@@ -520,18 +520,6 @@
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
src/peft/tuners/vera/layer.py (12 changes: 10 additions, 2 deletions)
@@ -217,9 +217,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
lambda_d = lambda_d.float()
lambda_b = lambda_b.float()

+sliced_A = vera_A[:, : self.in_features]
+sliced_B = vera_B[: self.out_features, :]
lambda_b = lambda_b.unsqueeze(-1)
lambda_d = lambda_d.unsqueeze(-1)
-output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out)
+output_tensor = transpose((lambda_b * sliced_B) @ (lambda_d * sliced_A), self.fan_in_fan_out)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
@@ -252,9 +254,15 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
vera_A = self.vera_A[active_adapter]
vera_B = self.vera_B[active_adapter]

+# As adapted layers may have different shapes and VeRA contains a single shared pair of A and B matrices,
+# we initialize these matrices with the largest required size for each dimension.
+# During the forward pass, required submatrices are sliced out from the shared vera_A and vera_B.
+sliced_A = vera_A[:, : self.in_features]
+sliced_B = vera_B[: self.out_features, :]

dropout = self.vera_dropout[active_adapter]
x = x.to(lambda_d.dtype)
-result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B)
+result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B)

result = result.to(previous_dtype)
return result
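For reference, a self-contained sketch of the sliced forward path added above (shapes are hypothetical; `lambda_d` and `lambda_b` stand in for the per-layer trainable vectors):

```python
import torch
import torch.nn.functional as F

rank, in_features, out_features = 4, 20, 100
x = torch.randn(2, in_features)       # a batch of inputs for this layer

vera_A = torch.randn(rank, 50)        # shared, sized for the largest in_features
vera_B = torch.randn(100, rank)       # shared, sized for the largest out_features
lambda_d = torch.randn(rank)          # per-layer trainable scaling vectors
lambda_b = torch.randn(out_features)

sliced_A = vera_A[:, :in_features]    # (rank, in_features)
sliced_B = vera_B[:out_features, :]   # (out_features, rank)

# Mirrors the updated forward: lambda_b * ((lambda_d * (x @ A.T)) @ B.T)
delta = lambda_b * F.linear(lambda_d * F.linear(x, sliced_A), sliced_B)
assert delta.shape == (2, out_features)
```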
src/peft/tuners/vera/model.py (29 changes: 12 additions, 17 deletions)
@@ -101,13 +101,11 @@ class VeraModel(BaseTuner):
def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)

-def _find_first_dim(self, config) -> tuple[int, int]:
+def _find_dim(self, config) -> tuple[int, int]:
"""
-Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension.
+Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.

This will be used for determining the size of the shared vera_A and vera_B matrices.

-This will throw an error if there are multiple layers of the same type with different shapes.
"""
model_config = getattr(self.model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
@@ -116,7 +114,7 @@ def _find_first_dim(self, config) -> tuple[int, int]:
peft_config = self._prepare_adapter_config(config, model_config)
peft_config = _maybe_include_all_linear_layers(peft_config, self.model)

-first_shape = None
+largest_shape = None
for key, module in self.model.named_modules():
if not self._check_target_module_exists(peft_config, key):
continue
@@ -128,33 +126,30 @@ def _find_first_dim(self, config) -> tuple[int, int]:
else:
continue

-if first_shape is None:
-first_shape = module_shape
+if largest_shape is None:
+largest_shape = module_shape
continue

-if module_shape != first_shape:
-raise ValueError(
-"Multiple target layers with different dimensions were specified. VeRA only supports a "
-f"single dimension size. Expected shape {first_shape}, got {module_shape}."
-)
+if module_shape != largest_shape:
+largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))

-if first_shape is None:
+if largest_shape is None:
msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`."
raise ValueError(msg)

-return first_shape
+return largest_shape

def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None:
-first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config)
+linear_out_dim, linear_in_dim = self._find_dim(config)

# use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them.
self.vera_A = BufferDict({}, persistent=config.save_projection)
self.vera_B = BufferDict({}, persistent=config.save_projection)

# deterministic init of vera_A and vera_B if we know the key
generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
-vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator)
-vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator)
+vera_A = _kaiming_init((config.r, linear_in_dim), generator=generator)
+vera_B = _kaiming_init((linear_out_dim, config.r), generator=generator)
self.vera_A[adapter_name] = vera_A
self.vera_B[adapter_name] = vera_B
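A standalone sketch of the reduction `_find_dim` now performs, using the example shapes from the updated docs (no error for mismatched shapes; the largest size wins per dimension):

```python
# (out_features, in_features) of the targeted nn.Linear layers
shapes = [(100, 20), (80, 50)]

largest_shape = shapes[0]
for shape in shapes[1:]:
    largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, shape))

# Shared vera_B is (100, r); shared vera_A is (r, 50).
assert largest_shape == (100, 50)
```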

tests/test_custom_models.py (1 change: 1 addition)
@@ -338,6 +338,7 @@
("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}),
("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}),
("Vanilla MLP 4 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0", "lin1"]}),
(
"Vanilla MLP 5 VeRA",
"MLP",