-
Notifications
You must be signed in to change notification settings - Fork 1
/
temp.rs
150 lines (134 loc) · 4.47 KB
/
temp.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#![allow(non_snake_case)]
use super::Options;
use num_complex::Complex;
/// One full turn in radians (2π); used to spread the initial root guesses around a circle.
const TWO_PI: f64 = 2.0 * std::f64::consts::PI;
/// Evaluates a real polynomial at the real point `zval` with Horner's scheme.
///
/// `coeffs` lists the coefficients from the highest degree down to the constant
/// term. An empty slice evaluates to `0.0`.
pub fn horner_eval_f(coeffs: &[f64], zval: f64) -> f64 {
    let mut result = 0.0;
    for &c in coeffs {
        result = result * zval + c;
    }
    result
}
/// Evaluates a real-coefficient polynomial at the complex point `zval` with
/// Horner's scheme.
///
/// `coeffs` lists the coefficients from the highest degree down to the constant
/// term. An empty slice evaluates to `0 + 0i`.
pub fn horner_eval_c(coeffs: &[f64], zval: &Complex<f64>) -> Complex<f64> {
    let mut acc = Complex::<f64>::new(0.0, 0.0);
    for c in coeffs {
        acc = acc * zval + c;
    }
    acc
}
/// Builds the starting root estimates for Aberth's method.
///
/// `coeffs` lists the polynomial coefficients from the highest degree down to
/// the constant term. The guesses are placed on a circle in the complex plane:
///
/// * center: `-coeffs[1] / (degree * coeffs[0])`, i.e. the mean of the roots
///   (by Vieta's formulas);
/// * radius: `(-P(center))^(1/degree)`, computed as a complex power so that a
///   positive `P(center)` still yields a usable (complex) radius;
/// * angles: `2π/degree` apart, offset by a quarter of that step.
///
/// A polynomial of degree < 1 (zero or one coefficient) has no roots, so an
/// empty `Vec` is returned.
pub fn initial_aberth(coeffs: &[f64]) -> Vec<Complex<f64>> {
    // Guard before `coeffs[1]` is indexed or `len() - 1` underflows.
    if coeffs.len() < 2 {
        return Vec::new();
    }
    let degree = coeffs.len() - 1;
    let n = degree as f64;
    let center = -coeffs[1] / (coeffs[0] * n);
    let poly_c = horner_eval_f(coeffs, center);
    let radius = Complex::<f64>::new(-poly_c, 0.0).powf(1.0 / n);
    let k = TWO_PI / n;
    (0..degree)
        .map(|idx| {
            let theta = k * (0.25 + idx as f64);
            center + radius * Complex::<f64>::new(theta.cos(), theta.sin())
        })
        .collect()
}
/// Aberth's method, sequential (Gauss–Seidel style).
///
/// Refines the root estimates in `zs` in place for the polynomial whose
/// coefficients are given in `coeffs`, highest degree first. Within one
/// iteration each root update immediately sees the already-updated estimates of
/// the roots processed before it. Roots whose residual falls below `1e-15` are
/// marked converged and skipped in later iterations.
///
/// Returns `(niter, true)` as soon as the largest residual of the roots still
/// being refined drops below `options.tolerance`. Here `niter` is the zero-based
/// index of that iteration. Otherwise it returns `(options.max_iters, false)`.
///
/// # Panics
/// Panics if `coeffs` is empty (the degree `coeffs.len() - 1` underflows).
pub fn aberth(coeffs: &[f64], zs: &mut [Complex<f64>], options: &Options) -> (usize, bool) {
    let m_zs = zs.len();
    let degree = coeffs.len() - 1;
    // Coefficients of P'(z), highest degree first; the constant term drops out.
    let coeffs1: Vec<_> = coeffs[0..degree]
        .iter()
        .enumerate()
        .map(|(i, ci)| ci * (degree - i) as f64)
        .collect();
    let mut converged = vec![false; m_zs];
    for niter in 0..options.max_iters {
        let mut tolerance = 0.0;
        for i in 0..m_zs {
            if converged[i] {
                continue;
            }
            // Work on a copy so `zs` can be lent immutably to `aberth_job`.
            // Writing it back right away gives the Gauss–Seidel update order.
            let mut zi = zs[i];
            if let Some(tol_i) = aberth_job(coeffs, i, &mut zi, &mut converged[i], zs, &coeffs1) {
                if tolerance < tol_i {
                    tolerance = tol_i;
                }
            }
            zs[i] = zi;
        }
        if tolerance < options.tolerance {
            return (niter, true);
        }
    }
    (options.max_iters, false)
}
/// Aberth's method, multi-threaded (Jacobi style, via rayon).
///
/// Same inputs and return contract as [`aberth`]. The difference is that every
/// root in an iteration is updated in parallel against a snapshot of the
/// estimates taken at the start of that iteration. The iteration count may
/// therefore differ from the sequential version.
///
/// # Panics
/// Panics if `coeffs` is empty (the degree `coeffs.len() - 1` underflows).
pub fn aberth_mt(coeffs: &[f64], zs: &mut [Complex<f64>], options: &Options) -> (usize, bool) {
    use rayon::prelude::*;
    let m_zs = zs.len();
    let degree = coeffs.len() - 1;
    // Coefficients of P'(z), highest degree first; the constant term drops out.
    let coeffs1: Vec<_> = (0..degree)
        .map(|i| coeffs[i] * (degree - i) as f64)
        .collect();
    // Snapshot of the estimates at the start of each iteration. The parallel
    // jobs read this while `zs` itself is being mutated.
    let mut zsc = vec![Complex::default(); m_zs];
    let mut converged = vec![false; m_zs];
    for niter in 0..options.max_iters {
        zsc.copy_from_slice(zs);
        // Largest residual among the roots still being refined (0.0 if none are).
        let tolerance = zs
            .par_iter_mut()
            .zip(converged.par_iter_mut())
            .enumerate()
            .filter(|(_, (_, done))| !**done)
            .filter_map(|(i, (zi, done))| aberth_job(coeffs, i, zi, done, &zsc, &coeffs1))
            .reduce(|| 0.0, f64::max);
        if tolerance < options.tolerance {
            return (niter, true);
        }
    }
    (options.max_iters, false)
}
/// Performs one Aberth correction for root `i`. Used by [`aberth`] and [`aberth_mt`].
///
/// * `coeffs` / `coeffs1`: coefficients of P and P', highest degree first.
/// * `zi`: current estimate of root `i`, updated in place.
/// * `converged`: set to `true` (and `None` returned) once the residual of
///   `zi` is below `1e-15`; `zi` is then left untouched.
/// * `zsc`: estimates of all roots; entry `i` is skipped.
///
/// Returns the residual `|Re P(zi)| + |Im P(zi)|`, measured before the update.
/// A NaN residual is reported as `f64::INFINITY`. Callers fold residuals with
/// `<` or `f64::max`, both of which silently drop NaN, so without this a
/// diverged root could be reported as converged.
fn aberth_job(
    coeffs: &[f64],
    i: usize,
    zi: &mut Complex<f64>,
    converged: &mut bool,
    zsc: &[Complex<f64>],
    coeffs1: &[f64],
) -> Option<f64> {
    let p_eval = horner_eval_c(coeffs, zi);
    // Residual as the L1 norm |re| + |im| of P(zi).
    let tol_i = p_eval.l1_norm();
    if tol_i < 1e-15 {
        *converged = true;
        return None;
    }
    // Aberth denominator: P'(zi) - P(zi) * sum_{j != i} 1 / (zi - zj).
    let mut p1_eval = horner_eval_c(coeffs1, zi);
    for (j, zj) in zsc.iter().enumerate() {
        if j != i {
            p1_eval -= p_eval / (*zi - zj);
        }
    }
    *zi -= p_eval / p1_eval;
    Some(if tol_i.is_nan() { f64::INFINITY } else { tol_i })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Palindromic degree-8 polynomial (highest degree first) shared by the tests.
    const COEFFS: [f64; 9] = [10.0, 34.0, 75.0, 94.0, 150.0, 94.0, 75.0, 34.0, 10.0];

    #[test]
    fn test_horner_eval() {
        // At z = 0 only the constant (last) coefficient survives.
        let at_zero = horner_eval_c(&COEFFS, &Complex::new(0.0, 0.0));
        assert_eq!(at_zero.re, 10.0);
        assert_eq!(at_zero.im, 0.0);

        // At z = 1 the value is the sum of all coefficients.
        let at_one = horner_eval_c(&COEFFS, &Complex::new(1.0, 0.0));
        assert_eq!(at_one.re, 576.0);
        assert_eq!(at_one.im, 0.0);
    }

    #[test]
    fn test_aberth() {
        let mut roots = initial_aberth(&COEFFS);
        let (niter, found) = aberth(&COEFFS, &mut roots, &Options::default());
        assert_eq!(niter, 5);
        assert!(found);
    }
}