-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_opts_tensor_CP.m
114 lines (104 loc) · 4.76 KB
/
run_opts_tensor_CP.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
function [out] = run_opts_tensor_CP(par)
%RUN_OPTS_TENSOR_CP Compare optimization solvers on a synthetic CP tensor problem.
%   OUT = RUN_OPTS_TENSOR_CP(PAR) seeds the global random stream with
%   PAR.initSeed, builds a dense test tensor with can_createTensorDense,
%   and runs every solver whose PAR.compare* flag equals 1 on the CP
%   objective tt_cp_fun. All solvers start from the same random initial
%   guess u0.
%
%   Fields of PAR read here:
%     initSeed             - seed for the 'mt19937ar' global random stream
%     par_ngmres           - options for ngmres / ngmres_o and the
%                            poblano_linesearch used by them
%     par_ncg, par_lbfgs   - options passed to ncg and lbfgs
%     par_als.maxIt        - number of ALS sweeps (read only if compareALS==1)
%     precStep1, precStep2 - step parameters passed to descentDir
%     compareNGMRES_ALS, compareNGMRESO_ALS, compareNGMRES_sd,
%     compareNGMRESO_sd, compareNCG, compareLBFGS, compareALS
%                          - set to 1 to run the corresponding solver
%     tensorPars           - (optional) [size collinearity rank noise1 noise2]
%                            passed to can_createTensorDense; when absent or
%                            empty the defaults [50 0.9 3 1 1] are used
%
%   OUT gets one field per solver that was run (out_ngmres_als,
%   out_ngmreso_als, out_ngmres_sd, out_ngmreso_sd, out_ncg, out_lbfgs,
%   out_als). If no solver is selected, OUT is an empty struct.
%   For ncg/lbfgs, TraceFuncEvals is turned into a running total. For ALS,
%   out_als.logf holds the objective after each sweep and out_als.logfev
%   holds the running total of the evaluation counts reported by the ALS
%   step.
rngStream = RandStream('mt19937ar','Seed', par.initSeed);
RandStream.setGlobalStream(rngStream);
lineSearch_ngmres=@(fg,u0,f0,g0,d) poblano_linesearch(fg,u0,f0,g0,1.,d,par.par_ngmres);
out = struct();  % always return something, even if no solver is selected
%------------------------------
% create a test tensor
% (this is the test from Tomasi and Bro (2006) and Acar et al. (2011) for
% a random dense tensor modified to get specified collinearity between the
% columns)
%------------------------------
modeSize=50; % size of each mode
r=3;         % number of rank-one components in the original noise-free random tensor
l1=1;        % magnitude of the first type of noise
l2=1;        % magnitude of the second type of noise
c=0.9;       % collinearity
pars=[modeSize c r l1 l2];
% Let the caller override the default test problem (same layout as pars).
if isfield(par,'tensorPars') && ~isempty(par.tensorPars)
    pars = par.tensorPars;
    r = pars(3);  % keep the CP rank used by ALS and by u0 consistent
end
T=can_createTensorDense(pars);
%--------------------------------------
% do some pre-processing
%--------------------------------------
nT2=norm(T)^2;
fg = @(u) tt_cp_fun(u,T,nT2);
% Steepest descent preconditioner without linesearch
M_sd=@(u,f,g) descentDir(u,f,g, ...
    fg, ... % function that computes f and g
    par.precStep1,par.precStep2);
% ALS preconditioner
M_als=@(u,f,g) can_ALSu(T,r,u,f,g,fg);
%--------------------------------------
% generate the random initial guess
%--------------------------------------
u0=rand(r*sum(size(T)),1);
%--------------------------------------
% call ngmres-als
%--------------------------------------
if par.compareNGMRES_ALS==1
    disp('+++ start n-gmres with ALS preconditioner')
    out.out_ngmres_als=ngmres(u0, ... % initial guess
        fg, ...                % function that computes f and g
        M_als, ...             % preconditioner function
        lineSearch_ngmres,...  % line search function
        par.par_ngmres);       % ngmres parameters
end
%--------------------------------------
% call other solvers for comparison
%--------------------------------------
if par.compareNGMRESO_ALS==1
    disp('+++ start n-gmres-o with ALS preconditioner')
    out.out_ngmreso_als=ngmres_o(u0, ... % initial guess
        fg, ...                % function that computes f and g
        M_als, ...             % preconditioner function
        lineSearch_ngmres,...  % line search function
        par.par_ngmres);       % ngmres parameters
end
if par.compareNGMRES_sd==1 % N-GMRES, preconditioner: SD with small step
    disp('+++ start n-gmres with descent preconditioner')
    out.out_ngmres_sd=ngmres(u0, ... % initial guess
        fg, ...                % function that computes f and g
        M_sd, ...              % preconditioner function
        lineSearch_ngmres,...  % line search function
        par.par_ngmres);       % ngmres parameters
end
if par.compareNGMRESO_sd==1 % N-GMRES-O, preconditioner: SD with small step
    disp('+++ start n-gmreso with descent preconditioner')
    out.out_ngmreso_sd=ngmres_o(u0, ... % initial guess
        fg, ...                % function that computes f and g
        M_sd, ...              % preconditioner function
        lineSearch_ngmres,...  % line search function
        par.par_ngmres);       % ngmres parameters
end
if par.compareNCG==1
    disp('+++ start n-cg directly')
    out.out_ncg=ncg(fg,u0,par.par_ncg);
    out.out_ncg.TraceFuncEvals=cumulative_counts(out.out_ncg.TraceFuncEvals);
end
if par.compareLBFGS==1
    disp('+++ start lbfgs')
    out.out_lbfgs=lbfgs(fg,u0,par.par_lbfgs);
    out.out_lbfgs.TraceFuncEvals=cumulative_counts(out.out_lbfgs.TraceFuncEvals);
end
if par.compareALS==1
    maxit = par.par_als.maxIt;
    disp('+++ start ALS')
    [out.out_als.logf, out.out_als.logfev] = run_als_sweeps(M_als, fg, u0, maxit);
end
end

function v = cumulative_counts(v)
% Replace v by its running sum, keeping v's original shape (row or column).
v(:) = cumsum(v(:));
end

function [logf, logfev] = run_als_sweeps(M_als, fg, u0, maxit)
% Run MAXIT sweeps of the ALS step M_als, starting from u0.
%   logf   - 1 x maxit row, objective value returned by each sweep
%   logfev - maxit x 1 column, running total of the evaluation counts
%            returned by M_als (the initial fg(u0) call is not counted)
% With maxit == 0 both outputs are empty instead of raising an error.
logf = zeros(1,maxit);
logfev = zeros(maxit,1);
u = u0;
[f, g] = fg(u);
for k = 1:maxit
    [u, f, g, fev] = M_als(u, f, g);
    logf(k) = f;
    logfev(k) = fev;
end
logfev = cumulative_counts(logfev);
end