# train_FedAvg_FedDSV.py
# (Header reconstructed from a web-page scrape: the original lines here were
# GitHub UI chrome and rendered gutter line numbers, not part of the script.)
"""Driver script: run one FedAvg + FedDSV (Federated Data Shapley Valuation)
training experiment for a single random seed, then pickle the results.

Usage:
    python train_FedAvg_FedDSV.py <seed_index>

where <seed_index> indexes into the pre-generated ``rd_seed_array.pickle``.
Expects the corrupted-data pickle under ``data/<dataset>/`` and writes six
result pickles under ``result/<dataset>/FedAvg/FedDSV/``.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import time
import os
from util import *
import sys

# read the input random seed index from the command line
i_rand = int(sys.argv[1])

# the name of dataset and the number of providers
dataset_name = 'CIFAR10'  # acceptable values are ['MNIST', 'KMNIST', 'FMNIST', 'CIFAR10']
n_providers = 100
print(f'dataset: {dataset_name}')
print(f'n_providers: {n_providers}')

# load the random seed array
with open('rd_seed_array.pickle', 'rb') as f:
    rd_seed_array = pickle.load(f)

# load the (pre-corrupted) federated data split for this provider count
data_path = f'data/{dataset_name.lower()}/corrupted_data_ndevices={n_providers}.pickle'
with open(data_path, 'rb') as f:
    providers_train_list, val_loader, test_loader = pickle.load(f)

# hyperparameters
step_size_local = 0.1                     # step size for local updates
step_size_global = step_size_local        # step size for global updates
n_epoch = 1                               # number of local epochs per round
batch_size = int(0.1 * n_providers)       # providers sampled per communication round
c_tol = 0.05                              # convergence tolerance to decide if DSV has converged
back_check = 10                           # the gap (in rounds) used to check convergence of DSV
# CIFAR10 is the hardest of the four datasets, so give it more rounds
if dataset_name == 'CIFAR10':
    n_commun = 1000  # number of communications
else:
    n_commun = 500   # number of communications

# set up random seeds for reproducibility
rd_seed = rd_seed_array[i_rand]
print(f"random seed: {rd_seed}")
np.random.seed(rd_seed)
# Fix: training runs in torch, so the torch RNG must be seeded too —
# otherwise runs are not reproducible despite the printed seed.
torch.manual_seed(rd_seed)

model, n_access, util_list, test_accu_list, time_used, FedDSV = train_FedAvg_FedDSV(
    providers_train_list, val_loader, test_loader, dataset_name,
    step_size_local, step_size_global, n_epoch, batch_size,
    n_commun, c_tol, back_check)
print(f"Time used: {time_used}(s)")

# store the results; create the output directory if it does not exist yet
out_dir = f'result/{dataset_name}/FedAvg/FedDSV'
os.makedirs(out_dir, exist_ok=True)
suffix = f'nproviders={n_providers}_rdseed={rd_seed}'
results = {
    'model': model,
    'n_access': n_access,
    'util_list': util_list,
    'time_cost': time_used,
    'test_accu_list': test_accu_list,
    'FedDSV': FedDSV,
}
for tag, obj in results.items():
    with open(f'{out_dir}/{tag}_{suffix}.pickle', 'wb') as f:
        pickle.dump(obj, f)