-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
69 lines (62 loc) · 2.73 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Global parameters

# Supported dataset identifiers.
DATASETS = [
    'mnist',
    'synthetic',
    'emnist',
    'fashionmnist',
    'cifar',
    'celeba',
]

# Supported model identifiers.
MODELS = [
    'logistic',
    '2nn',
    '1nn',
    'cifar',
    'vgg',
    'celebacnn',
    'resnet',
    'resnet2',
]

# Algorithm key -> client class *name*. The values are plain strings;
# presumably they are resolved to actual classes by name elsewhere (confirm).
ALGORITHMS = {
    'me': 'MeClient',
    'me_fair': 'MeFairClient',
    'proj': 'ProjClient',
    'proj_fair': 'ProjFairClient',
    'ditto': 'DittoClient',
    'lp': 'LpClient',
    'lp_proj': 'LpProjClient',
    'me_fair2': 'MeFair2Client',
    'sketch': 'SketchClient',
    'lg': 'LgClient',
    'fedavg': 'FedAvgClient',
    'perfedavg': 'PerFedAvgClient',
    'local': 'LocalClient',
    'lp_projnew': 'LpProjNewClient',
    'lp_projdiff': 'LpProjDiffClient',
    'lbgm': 'LBGMClient',
    'qsgd': 'QSGDClient',
    'dgc': 'DGCClient',
}

# The valid optimizer keys are exactly the algorithm keys. This is a live
# dict.keys() view of ALGORITHMS, not a list copy.
OPTIMIZERS = ALGORITHMS.keys()
class ModelConfig(object):
    """Look up the input shape and class count for a (dataset, model) pair.

    Instances are callable: ``ModelConfig()(dataset, model)`` returns a dict
    with the keys ``'input_shape'`` and ``'num_class'``.
    """

    # Models that get a flat 784-value (28 * 28) input on the MNIST-family
    # datasets instead of an image-shaped tuple.
    _FLAT_MODELS = ('logistic', '2nn', '1nn')

    def __init__(self):
        pass

    def __call__(self, dataset, model):
        """Return the model-construction parameters for *dataset* / *model*.

        Args:
            dataset: Dataset name. Only the text before the first ``'_'`` is
                considered, so e.g. ``'mnist_foo'`` is treated as ``'mnist'``.
            model: Model name. It only influences the result for ``'mnist'``,
                ``'emnist'`` and ``'fashionmnist'``.

        Returns:
            ``{'input_shape': shape, 'num_class': n}``, where ``shape`` is
            either an int (flat input) or a 3-tuple that looks like
            (height, width, channels).

        Raises:
            ValueError: if the dataset prefix is not recognised.
        """
        dataset = dataset.split('_')[0]
        flat = model in self._FLAT_MODELS

        if dataset == 'mnist':
            if flat:
                return {'input_shape': 784, 'num_class': 10}
            return {'input_shape': (28, 28, 1), 'num_class': 10}
        elif dataset == 'emnist':
            if flat:
                return {'input_shape': 784, 'num_class': 62}
            return {'input_shape': (28, 28, 1), 'num_class': 62}
        elif dataset == 'fashionmnist':
            if flat:
                return {'input_shape': 784, 'num_class': 10}
            if model in ('resnet', 'resnet2'):
                # Single-channel input at 224x224 for the ResNet variants.
                return {'input_shape': (224, 224, 1), 'num_class': 10}
            # Every other model used to fall off the end of this branch and
            # silently return None. Use the native 28x28x1 shape instead,
            # mirroring the mnist/emnist branches.
            return {'input_shape': (28, 28, 1), 'num_class': 10}
        elif dataset == 'synthetic':
            return {'input_shape': 60, 'num_class': 10}
        elif dataset == 'cifar':
            return {'input_shape': (32, 32, 3), 'num_class': 10}
        elif dataset == 'celeba':
            return {'input_shape': (128, 128, 3), 'num_class': 2}
        else:
            raise ValueError('Not support dataset {}!'.format(dataset))


# Shared callable instance: MODEL_PARAMS(dataset, model) -> config dict.
MODEL_PARAMS = ModelConfig()
# Criterion (loss) identifiers. The names look like cross-entropy and
# mean-squared-error losses; confirm against the code that consumes them.
CRITERIA = ['celoss', 'mseloss']

# Attack identifiers.
ATTACKS = [
    'same_value',
    'sign_flip',
    'gaussian',
    'data_poison',
]

# Server key -> server class *name*. The values are plain strings;
# presumably they are resolved to actual classes by name elsewhere (confirm).
SERVERTYPE = {
    'server': 'Server',
    'robust_server': 'RobustServer',
    'server_sketch': 'ServerSketch',
    'server_lg': 'ServerLg',
    'server_local': 'ServerLocal',
    'server_lbgm': 'ServerLBGM',
    'server_gradient': 'ServerGradient',
}

# Live dict.keys() view of SERVERTYPE (not a list copy).
SERVERS = SERVERTYPE.keys()

# Aggregation rule identifiers.
AGGR = ['mean', 'median', 'krum']