-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathargs.py
78 lines (67 loc) · 3.84 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import Optional, Dict, Sequence, Union, List
from dataclasses import dataclass, field
import transformers
import typing
@dataclass
class LorraArguments:
    """LoRRA-related options: layer selections, alpha, dataset names and flags.

    ``target_layers`` and ``transform_layers`` declare no default, so both must
    be supplied when the dataclass is instantiated. Every other field has a
    default.
    """

    # Comma-separated layer indices as a single string (format per help text).
    target_layers: str = field(
        metadata={"help": "Layers for Representation. Layers are seperate by `,` eg: `10,12,14,16,18,20` "}
    )
    # Same string format as ``target_layers``.
    transform_layers: str = field(
        metadata={"help": "Layers for Representation. Layers are seperate by `,` eg: `10,12,14,16,18,20` "}
    )
    lorra_alpha: float = field(default=5, metadata={"help": ""})  # LoRRA Hyperparameters
    # Both dataset lists default to None, so they are annotated Optional.
    trainsets: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of trainsets for finetuning the corresponding Concepts/Functions, separated by # for commandline inputs (eg: ['AlpacaSupervisedDataset'])"},
    )
    valsets: Optional[List[str]] = field(
        default=None,
        metadata={"help": "A list of valsets for finetuning the corresponding Concepts/Functions, separated by # for commandline inputs (eg: ['AlpacaSupervisedDataset'])"},
    )
    adv_string: str = field(
        default="",
        metadata={"help": "adversarial string for harmful prompt. (eg: Start with 'Sure here's')"},
    )
    full_layers: bool = field(
        default=False,
        metadata={"help": "Whether to drop not used layer during training"},
    )

    def to_dict(self) -> Dict[str, object]:
        """Return a plain dict holding a subset of this object's fields.

        Returns:
            A mapping with the keys ``target_layers``, ``transform_layers``,
            ``lorra_alpha``, ``trainsets``, ``valsets`` and ``full_layers``.
        """
        # NOTE(review): ``adv_string`` is not part of the returned mapping;
        # confirm whether that omission is intentional before adding it.
        return {
            "target_layers": self.target_layers,
            "transform_layers": self.transform_layers,
            "lorra_alpha": self.lorra_alpha,
            "trainsets": self.trainsets,
            "valsets": self.valsets,
            "full_layers": self.full_layers,
        }
@dataclass
class LoraArguments:
    """LoRA adapter options; every field has a default."""

    lora_r: int = field(default=8)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    # A factory is used so each instance receives its own list object.
    lora_target_modules: Union[List[str], str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )
    lora_weight_path: str = field(default="")
    lora_bias: str = field(default="none")
    q_lora: bool = field(default=False)
@dataclass
class ModelArguments:
    """Model selection options: base model, optional adapter, and a LoRA flag."""

    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-chat-hf")
    # Defaults to None, so the annotation is Optional rather than plain str.
    adapter_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Adapater name"},
    )
    use_lora: bool = field(
        default=False,
        metadata={"help": "Use LoRA (default: False)"},
    )
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Subclass of ``transformers.TrainingArguments`` declaring the fields below.

    All fields declared here carry defaults.
    """

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    grouped_to_max_length: bool = field(
        default=False,
        metadata={"help": "Group to chunks of max length for pretraining"},
    )
    use_refusal_retain: Optional[bool] = field(
        default=True,
        metadata={"help": "whether to train on refusal retain set for Llama models"},
    )
    sc_train_subset: Optional[List[str]] = field(
        default=None,
        metadata={"help": "subset of the sc train set to train on"},
    )
    log_every: Optional[int] = field(
        default=10,
        metadata={"help": "log loss every log_every steps"},
    )
    # Accepted values are listed in each help string below.
    sc_train_seq_type: Optional[str] = field(
        default="all_text",
        metadata={
            "help": "what portion of the sequence to train on. can be all_text or assistant_response"
        },
    )
    coeff_schedule: Optional[str] = field(
        default="linear_converge",
        metadata={
            "help": "schedule for the coefficients. can be linear_converge or constant"
        },
    )
    sc_loss_type: Optional[str] = field(
        default="orig_act_dotprod",
        metadata={
            "help": "type of loss function for shortcircuiting. can be orig_act_dotprod, rand_vec_norm, pos_constant_rmu_{coeff}, center_constant_rmu_{coeff}"
        },
    )