-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_comparison.m
97 lines (77 loc) · 2.87 KB
/
run_comparison.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
clear
close all
% Name of the dataset to run. A file of this name is loaded from
% datasets_clean/, and the name also selects the per-dataset
% hyper-parameters further down.
dataset = 'covtype'
% NOTE(review): the .mat file is expected to define X (one row per sample,
% one column per feature) and y (one row per sample) -- confirm.
load(fullfile('datasets_clean', dataset));

% Scale every feature column except the first by its largest absolute
% value, so its entries lie in [-1, 1]. Column 1 is left untouched
% (presumably a bias/intercept column -- TODO confirm against data prep).
for col = 2:size(X, 2)
    colScale = max(abs(X(:, col)));
    % An all-zero column has max|x| = 0; dividing by it would turn the
    % whole column into NaN (0/0), so such columns are left as they are.
    if colScale > 0
        % Whole-column division instead of an element-by-element loop
        % (element-wise assignment is very slow, especially if X is sparse).
        X(:, col) = X(:, col) / colScale;
    end
end

num_feature = size(X, 2);
total_sample = size(y, 1);
% Per-dataset hyper-parameters consumed by the solver calls below:
%   mu         - last argument of FedNS / FedNDES
%   no_workers - number of workers; samples per worker is
%                floor(total_sample / no_workers)
%   rho, alpha - passed to newton_ADMM_Hk
%   k          - passed to FedNS (FedNDES receives 2*k)
% `switch` matches both char vectors and string scalars, and the
% `otherwise` branch fails fast on a dataset with no configuration instead
% of leaving these variables undefined for later code.
switch dataset
    case 'covtype'
        mu = 1;
        no_workers = 200;
        rho = 50;
        alpha = 1;
        k = 20;
    case 'SUSY'
        mu = 1;
        no_workers = 60;
        rho = 50;
        alpha = 1;
        k = 20;
    case 'cod-rna'
        mu = 1;
        no_workers = 60;
        rho = 50;
        alpha = 1;
        k = 10;
    case 'phishing' % iter = 11
        mu = 1;
        no_workers = 40;
        rho = 0.1;
        alpha = 0.25;
        k = 17;
    otherwise
        error('run_comparison:unknownDataset', ...
            'No hyper-parameters configured for dataset "%s".', char(dataset));
end
% ---- Experiment settings shared by every solver -------------------------
lambda_logistic = 1E-3;   % passed as lambda_logistic to every solver
num_iter = 100;           % passed as num_iter to every solver
repeat = 3;               % repetitions; one result column per repetition

% Give every worker the same number of samples; total_sample is reduced
% to the largest multiple of no_workers not exceeding it.
dataSamples_per_worker = floor(total_sample / no_workers);
total_sample = dataSamples_per_worker * no_workers;

X_fede = X;
y_fede = y;

% Reference objective: the last value reported by standard Newton. Every
% other solver receives it as obj0.
obj_snewton = standard_newton(X_fede, y_fede, no_workers, num_feature, dataSamples_per_worker, num_iter, lambda_logistic);
obj0 = obj_snewton(end);

% Leading arguments common to all six compared solvers.
shared = {X_fede, y_fede, no_workers, num_feature, dataSamples_per_worker, num_iter, obj0, lambda_logistic};

% Column `rep` of each obj_*/loss_* array holds repetition `rep`.
% NOTE(review): the call order is kept fixed, since solvers that draw
% random numbers (e.g. the "Gaussian" option) share the global RNG stream.
for rep = 1:repeat
    [obj_FedNS(:, rep), loss_FedNS(:, rep), ~] = FedNS(shared{:}, k, "Gaussian", mu);
    [obj_FedNDES(:, rep), loss_FedNDES(:, rep), ~] = FedNDES(shared{:}, 2*k, "Gaussian", mu);
    [obj_FedNewton(:, rep), loss_FedNewton(:, rep), ~] = FedNewton(shared{:});
    [obj_GD(:, rep), loss_GD(:, rep), ~] = GD(shared{:});
    [obj_znewton(:, rep), loss_znewton(:, rep), ~] = newton_zero(shared{:});
    [obj_newton_aadmm_Hk(:, rep), loss_newton_aadmm_Hk(:, rep), ~] = newton_ADMM_Hk(shared{:}, rho, alpha);
end
% Plot f(x^t) - f(x^*) for every method, averaged across repetitions
% (mean along dimension 2, i.e. over the per-repetition columns).
% Plot order must match the legend labels below:
%   loss_GD -> FedAvg, loss_FedNewton -> FedNewton, loss_znewton -> FedNL,
%   loss_newton_aadmm_Hk -> FedNew, loss_FedNS -> FedNS,
%   loss_FedNDES -> FedNDES.
h = figure(1);
semilogy(mean(loss_GD, 2), 'LineWidth', 2);
hold on
semilogy(mean(loss_FedNewton, 2), 'LineWidth', 2);
semilogy(mean(loss_znewton, 2), 'k', 'LineWidth', 2);
semilogy(mean(loss_newton_aadmm_Hk, 2), 'LineWidth', 2);
semilogy(mean(loss_FedNS, 2), 'LineWidth', 2);
semilogy(mean(loss_FedNDES, 2), 'LineWidth', 2);
xlabel({'Number of communication rounds'}, 'fontsize', 16, 'fontname', 'Times New Roman')
ylabel('$f(x^t) - f(x^*)$', 'Interpreter', 'latex', 'fontsize', 16, 'fontweight', 'bold')
legend({'FedAvg', 'FedNewton', 'FedNL', 'FedNew', 'FedNS', 'FedNDES'}, 'fontsize', 16, 'Location', 'best');
ylim([1E-13 1E4])
set(gca, 'fontsize', 14, 'fontweight', 'bold');

% Create the output folder if needed: without it, print/save would fail
% only after all the (expensive) solver runs have already finished.
outDir = './results';
if ~exist(outDir, 'dir')
    mkdir(outDir);
end
% char() keeps the file names correct even if dataset is a string scalar
% (concatenating a string with chars would yield a string array).
print(h, fullfile(outDir, [char(dataset), '.pdf']), '-dpdf', '-r600')
save(fullfile(outDir, [char(dataset), '_data.mat']), 'loss_znewton', 'loss_newton_aadmm_Hk', 'loss_GD', 'loss_FedNewton', 'loss_FedNS', 'loss_FedNDES')