intermediate_layer.py
import argparse
import os
from attacks.intermediate_layer_attack import intermediate_layer_attack
from attacks.intermediate_layer_attack_imagenet import intermediate_layer_attack_imagenet
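# Membership inference (MI) attack based on the output of one of the target
# model's intermediate layers (selected via the -l flag below).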
parser = argparse.ArgumentParser(description='MI attack based on the output of intermediate layers.')
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['mnist', 'cifar10', 'cifar100', 'imagenet'], help='Indicate the dataset and target model. If you trained your own target model, the model choice will be overwritten.')
parser.add_argument('-m', '--model_path', type=str, default='none', help='Indicate the path to the target model. If you used train_target_model.py to train the model, leave this field at its default value.')
parser.add_argument('-a', '--attack_model', type=str, default='NN', choices=['NN', 'RF', 'XGBoost'], help='MI attack model (default is NN).')
parser.add_argument('-s', '--sampling', type=str, default='none', choices=['none', 'undersampling', 'oversampling'], help='Indicate sampling. Useful for highly imbalanced cases.')
parser.add_argument('-c', '--attacker_knowledge', type=float, default=0.8, help='The portion of samples available to the attacker. Default is 0.8.')
parser.add_argument('-n', '--number_of_target_classes', type=int, default=0, help='Limit the MI attack to a limited number of classes, for efficiency.')
parser.add_argument('-i', '--imagenet_path', type=str, default='../imagenet/', help='Path to the ImageNet dataset.')
parser.add_argument('-l', '--intermediate_layer', type=int, default=-1, help='Possible values: {-1, -2, -3}. May vary based on the target model.')
parser.add_argument('--no_train_for_all', default=True, help='Disable training an attack model for all samples.', action='store_false')
parser.add_argument('--no_train_for_correctly_classified', default=True, help='Disable training a separate attack model for correctly classified samples.', action='store_false')
parser.add_argument('--no_train_for_incorrect_misclassified', default=True, help='Disable training a separate attack model for misclassified samples.', action='store_false')
parser.add_argument('--verbose', default=False, help='Print full details.', action='store_true')
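# Note: the --no_train_* flags use action='store_false' with default=True, so all
# three attack models are trained unless the corresponding flag is passed.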
args = parser.parse_args()
if __name__ == '__main__':
    verbose = args.verbose
    dataset = args.dataset
    model_name = args.model_path
    intermediate_layer = args.intermediate_layer
    attack_classifier = args.attack_model
    sampling = args.sampling
    what_portion_of_samples_attacker_knows = args.attacker_knowledge
    if what_portion_of_samples_attacker_knows < 0.1 or what_portion_of_samples_attacker_knows > 0.9:
        print('Error: Attacker knowledge should be in [0.1, 0.9] range!')
        exit()
    show_MI_attack = args.no_train_for_all
    show_MI_attack_separate_result = args.no_train_for_correctly_classified
    show_MI_attack_separate_result_for_incorrect = args.no_train_for_incorrect_misclassified
    save_dir = os.path.join(os.getcwd(), 'saved_models')
    num_classes = 10
    if dataset == "mnist" or dataset == "cifar10":
        num_classes = 10
        num_targeted_classes = 10
    elif dataset == "cifar100":
        num_classes = 100
        num_targeted_classes = 100
    elif dataset == "imagenet":
        num_classes = 1000
        num_targeted_classes = 100
    else:
        print("Unknown dataset!")
        exit()
    if num_classes > args.number_of_target_classes > 0:
        num_targeted_classes = args.number_of_target_classes
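    # The ImageNet variant additionally takes the dataset path (args.imagenet_path).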
    if dataset == "imagenet":
        intermediate_layer_attack_imagenet(dataset, intermediate_layer, attack_classifier, sampling, what_portion_of_samples_attacker_knows, num_classes, num_targeted_classes, model_name, verbose, show_MI_attack, show_MI_attack_separate_result, show_MI_attack_separate_result_for_incorrect, args.imagenet_path)
    else:
        intermediate_layer_attack(dataset, intermediate_layer, attack_classifier, sampling, what_portion_of_samples_attacker_knows, num_classes, num_targeted_classes, model_name, verbose, show_MI_attack, show_MI_attack_separate_result, show_MI_attack_separate_result_for_incorrect)
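# Example invocations (a sketch; flag values are illustrative, and the ImageNet
# path assumes the default ../imagenet/ layout):
#   python intermediate_layer.py -d cifar10 -a XGBoost -l -2 --verbose
#   python intermediate_layer.py -d imagenet -n 20 -i ../imagenet/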