-
Notifications
You must be signed in to change notification settings - Fork 0
/
attack.py
63 lines (49 loc) · 2.04 KB
/
attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Created by Ethan Ruan
# Email: [email protected]
from typing import List
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from arguement import parse_args
from common import tools, SememicUnit, HuggingFaceWrapper
from config import DEVICES, Pattern, ArgStyle, SeparatorType
from dataset import DataLoader
from segmentation import Separator
from validation import Validator
from perturbation_search import Searcher
from substitution import Substituter
from evaluation import Evaluator
def _run_attack() -> None:
    """Run the adversarial-attack pipeline end to end.

    Parses the command-line arguments, builds the victim model and the
    attack components, then processes every adversarial text candidate
    (segmentation followed by perturbation search) while collecting the
    per-example information used by the evaluator.
    """
    # Parse the command-line arguments and apply them to the shared tooling.
    cli_args = parse_args()
    tools.setup_from_args(cli_args)
    style = cli_args.style

    # Dataset loader.
    loader = DataLoader()

    # Victim (attacked) model: a sequence classifier and its tokenizer,
    # both loaded from the path registered for the selected style.
    model_path = ArgStyle.Victim_Model[style]
    victim_classifier = AutoModelForSequenceClassification.from_pretrained(model_path)
    victim_classifier = victim_classifier.to(DEVICES[1])
    victim_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    victim = HuggingFaceWrapper(victim_classifier, victim_tokenizer)

    # Attack components, created in the same order as before: validator,
    # substituter, searcher, word segmenter (LTP backend), evaluator.
    validator = Validator(victim)
    substituter = Substituter()
    searcher = Searcher(validator, substituter)
    separator = Separator(SeparatorType.LTP)
    evaluator = Evaluator()

    # Build the list of texts to attack from the original examples.
    examples = loader.generate_examples(style)
    adv_texts = validator.generate_adv_texts(examples, style)
    evaluator.set_origin_example_count(len(adv_texts))

    for round_no, adv_text in enumerate(adv_texts):
        tools.show_log(f'adv_text: {round_no} Round')
        # Word segmentation.
        units: List[SememicUnit] = separator.split(adv_text)
        # Perturbation search over the segmented units.
        searcher.perform(units, adv_text)
        # Collect the information needed for the evaluation metrics.
        evaluator.add(adv_text.adversary_info)
        # NOTE(review): assumed to be a per-round separator line (the
        # original indentation was lost) - confirm against the repository.
        tools.show_log('------------------------------------------------------------------------------------')

    # Compute the experiment metrics once all rounds are done.
    evaluator.compute()


if __name__ == '__main__':
    _run_attack()