ft_handlers.py
import torch.nn as nn
from collections import defaultdict, OrderedDict

# `requires_grad` printout of the classifier-head parameters of a PEFT-wrapped
# model; the `original_module` path comes from PEFT's `modules_to_save` wrapper,
# which keeps the classification head trainable alongside the LoRA adapters:
"""
True base_model.model.classifier.original_module.dense.weight
True base_model.model.classifier.original_module.dense.bias
True base_model.model.classifier.original_module.out_proj.weight
True base_model.model.classifier.original_module.out_proj.bias
"""

class LoRAHandler(nn.Module):
    """Extracts the effective per-layer weight deltas (B @ A) from a PEFT LoRA model."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def get_ft_parameters(self):
        # Collect the A/B factor pair of each LoRA-adapted layer, keyed by the
        # name of the underlying base weight.
        layer2lora_parameters = defaultdict(dict)
        sd = self.base_model.state_dict()
        for key, val in sd.items():
            if 'lora_A.default' in key:
                base_name = key.replace('.lora_A.default', '')
                layer2lora_parameters[base_name]['A'] = val
            elif 'lora_B.default' in key:
                base_name = key.replace('.lora_B.default', '')
                layer2lora_parameters[base_name]['B'] = val
        task_parameters = {}
        for name, key2val in layer2lora_parameters.items():
            # A: [r, I], B: [O, r], so B @ A: [O, r] @ [r, I] -> [O, I], the same
            # shape as the frozen base weight. Note that the LoRA scaling factor
            # (lora_alpha / r) is not applied here.
            task_parameters[name] = key2val['B'] @ key2val['A']
        return OrderedDict(sorted(task_parameters.items()))

    def get_model(self):
        # PEFT's `get_base_model` is a method; it must be called, not returned
        # as a bound method.
        return self.base_model.get_base_model()
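

# A minimal, self-contained sanity check for LoRAHandler (a hypothetical demo,
# not part of the original module). It builds a dummy model whose state-dict
# keys mimic PEFT's default `lora_A.default` / `lora_B.default` naming and
# verifies that the recovered delta equals B @ A. `_DummyLoRALayer`,
# `_DummyModel`, and `_demo_lora_handler` are illustrative names.
def _demo_lora_handler():
    import torch

    class _DummyLoRALayer(nn.Module):
        def __init__(self, in_features=8, out_features=4, r=2):
            super().__init__()
            # ModuleDicts reproduce the `.lora_A.default.weight` key layout.
            self.lora_A = nn.ModuleDict({'default': nn.Linear(in_features, r, bias=False)})
            self.lora_B = nn.ModuleDict({'default': nn.Linear(r, out_features, bias=False)})

    class _DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = _DummyLoRALayer()

    model = _DummyModel()
    deltas = LoRAHandler(model).get_ft_parameters()
    a = model.layer.lora_A['default'].weight  # [r, in_features]
    b = model.layer.lora_B['default'].weight  # [out_features, r]
    # The delta is keyed by the base weight name, 'layer.weight'.
    assert torch.allclose(deltas['layer.weight'], b @ a)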

class FFTHandler(nn.Module):
    """Handler for fully fine-tuned (FFT) models: every parameter counts as a
    fine-tuning parameter, so the whole state dict is returned."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def get_ft_parameters(self):
        return OrderedDict(sorted(self.base_model.state_dict().items()))

    def get_final_model(self, **kwargs):
        return self.base_model


class GeneralHandler(nn.Module):
    """Fallback handler with the same behavior as FFTHandler: return the full
    (sorted) state dict and the model itself."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def get_ft_parameters(self):
        return OrderedDict(sorted(self.base_model.state_dict().items()))

    def get_final_model(self, **kwargs):
        return self.base_model
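

# Hypothetical usage notes (the model name and hyperparameters below are
# illustrative, not taken from this repository). With the real
# transformers/peft stack the wiring would look roughly like:
#
#   from transformers import AutoModelForSequenceClassification
#   from peft import LoraConfig, get_peft_model
#
#   base = AutoModelForSequenceClassification.from_pretrained('roberta-base')
#   peft_model = get_peft_model(base, LoraConfig(task_type='SEQ_CLS', r=8))
#   # ... fine-tune ...
#   deltas = LoRAHandler(peft_model).get_ft_parameters()  # {layer: B @ A}
#   unwrapped = LoRAHandler(peft_model).get_model()
#
# The guarded smoke test below exercises the handlers on tiny synthetic
# modules, so it needs no checkpoints or network access.
if __name__ == '__main__':
    _demo_lora_handler()

    tiny = nn.Linear(4, 2)
    fft_params = FFTHandler(tiny).get_ft_parameters()
    # For full fine-tuning, the "task parameters" are just the sorted state dict.
    print(list(fft_params.keys()))  # -> ['bias', 'weight']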