-
Notifications
You must be signed in to change notification settings - Fork 21
/
NumericalCalculations.jl
229 lines (179 loc) · 8.52 KB
/
NumericalCalculations.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# TODO deprecate testshuffle
# Numerics sanity check for `<:AbstractRelativeMinimize` factors (plain or wrapped in a Mixture).
# No checks are performed for this factor family; the method is a no-op and returns `nothing`.
function _checkErrorCCWNumerics(ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
                                testshuffle::Bool=false) where {N_,F<:AbstractRelativeMinimize,S,T}
  return nothing
end
# Numerics sanity check for `<:AbstractRelativeRoots` factors (plain or wrapped in a Mixture).
# Throws an `ErrorException` unless the factor is non-partial, `testshuffle` is false, and the
# measurement dimension is at least the variable dimension (`zDim >= xDim`).
# Returns `nothing` otherwise.
function _checkErrorCCWNumerics(ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
                                testshuffle::Bool=false) where {N_,F<:AbstractRelativeRoots,S,T}
  #
  # `(zDim < xDim && !partial) || testshuffle || partial` reduces to the condition below.
  # Once it is false, `zDim >= xDim && !partial` always holds, so the former
  # "Unresolved numeric" fallback branch was unreachable and has been removed.
  if ccwl.zDim < ccwl.xDim || testshuffle || ccwl.partial
    error("<:AbstractRelativeRoots factors with less measurement dimensions than variable dimensions have been discontinued, easy conversion to <:AbstractRelativeMinimize is the better option.")
  end
  return nothing
end
# Starting-point perturbation for `<:AbstractRelativeMinimize` factors: none, so the
# scalar `0` is returned (adding it to the target is a no-op).
function _perturbIfNecessary(fcttype::Union{F,<:Mixture{N_,F,S,T}},
                             len::Int=1,
                             perturbation::Real=1e-10 ) where {N_,F<:AbstractRelativeMinimize,S,T}
  return 0
end
#
# Starting-point perturbation for `<:AbstractRelativeRoots` factors: a vector of `len`
# standard-normal samples scaled by `perturbation`.
function _perturbIfNecessary(fcttype::Union{F,<:Mixture{N_,F,S,T}},
                             len::Int=1,
                             perturbation::Real=1e-10 ) where {N_,F<:AbstractRelativeRoots,S,T}
  jitter = randn(len)
  return perturbation .* jitter
end
#
# Numeric solve for prior factors (internal use only; split out of the approxDeconv functions).
# Nothing is solved here: the initial guess `u0` is handed back as-is, and the remaining
# arguments (including the `perturb` keyword) are accepted only for signature compatibility.
function _solveLambdaNumeric(fcttype::AbstractPrior,
                             objResX::Function,
                             residual::AbstractVector{<:Real},
                             u0::AbstractVector{<:Real},
                             islen1::Bool=false;
                             perturb::Real=1e-10 )
  return u0
end
#
# Numeric solve for `<:AbstractRelativeRoots` factors: root-find `objResX(x) == 0` with NLsolve,
# starting from `u0`, and return the solver's `zero` field. `residual` and `islen1` are unused
# here; they exist so all `_solveLambdaNumeric` methods share one call signature.
function _solveLambdaNumeric( fcttype::Union{F,<:Mixture{N_,F,S,T}},
                              objResX::Function,
                              residual::AbstractVector{<:Real},
                              u0::AbstractVector{<:Real},
                              islen1::Bool=false ) where {N_,F<:AbstractRelativeRoots,S,T}
  #
  # in-place residual callback expected by NLsolve: copy objResX(x) into the solver buffer
  function inplaceResidual!(resbuf, x)
    resbuf .= objResX(x)
  end
  sol = NLsolve.nlsolve(inplaceResidual!, u0; inplace=true)
  return sol.zero
end
# Numeric solve for `<:AbstractRelativeMinimize` factors: minimize the squared 2-norm of the
# residual with Optim, starting from `u0`, and return the minimizer.
# Side effect: `residual` is overwritten with `objResX(x)` on every cost evaluation.
# When `islen1` is true, BFGS is used instead of Optim's default method (the caller notes
# that Nelder-Mead cannot handle 1-dim problems).
function _solveLambdaNumeric( fcttype::Union{F,<:Mixture{N_,F,S,T}},
                              objResX::Function,
                              residual::AbstractVector{<:Real},
                              u0::AbstractVector{<:Real},
                              islen1::Bool=false ) where {N_,F<:AbstractRelativeMinimize,S,T}
  #
  # wrt #467 allow residual to be standardize for Roots and Minimize and Parametric cases.
  # `sum(abs2, residual)` avoids allocating a temporary `residual.^2` array on every call.
  cost = (x) -> (residual .= objResX(x); sum(abs2, residual))
  r = islen1 ? Optim.optimize(cost, u0, Optim.BFGS()) : Optim.optimize(cost, u0)
  #
  return r.minimizer
end
## ================================================================================================
## Heavy dispatch for all AbstractFactor / Mixture cases below
## ================================================================================================
# Internal dispatch helper that selects sample column(s) from a value where rows are
# dimensions and columns are samples:
# - matrices return `view(mat, row, col)`,
# - vectors return `view(vec, col)` (the row index is ignored),
# - any other value is returned unchanged.
function _viewdim1or2(val, row, col)
  return val
end
function _viewdim1or2(vec::AbstractVector, row, col)
  return view(vec, col)
end
function _viewdim1or2(mat::AbstractMatrix, row, col)
  return view(mat, row, col)
end
# Build the `CalcFactor` for a regular (non-Mixture) factor, using the user functor
# `ccwl.usrfnc!` directly.
function _buildCalcFactorMixture( ccwl::CommonConvWrapper,
                                  _fmd_,
                                  smpid,
                                  measurement_,
                                  varParams )
  #
  nmeas = length(measurement_)
  return CalcFactor(ccwl.usrfnc!, _fmd_, smpid, nmeas, measurement_, varParams)
end

# Build the `CalcFactor` for a Mixture factor by passing through its `.mechanics`
# factor (similar to pre-v0.20 behavior).
function _buildCalcFactorMixture( ccwl::CommonConvWrapper{Mixture{N_,F,S,T}},
                                  _fmd_,
                                  smpid,
                                  measurement_,
                                  varParams ) where {N_,F <: FunctorInferenceType,S,T}
  #
  mech = ccwl.usrfnc!.mechanics
  nmeas = length(measurement_)
  return CalcFactor(mech, _fmd_, smpid, nmeas, measurement_, varParams)
end
"""
$SIGNATURES
Internal function to build the lambda pre-objective function used to find factor residuals
for sample (column) `smpid`.

Returns a tuple `(unrollHypo!, target)`:
- `unrollHypo!()` calls the factor's `CalcFactor` with the measurement slices for `smpid`
  (via `_viewdim1or2`) followed by the `smpid` column views of each active-hypothesis parameter.
- `target` is the decision-variable view passed in (by default
  `view(ccwl.params[ccwl.varidx], cpt_.p, smpid)`). Writing into it changes the values
  seen by `unrollHypo!`.

Side effect: `cpt_.res` is reset to all zeros.

Notes
- Unless passed in as separate arguments, this assumes already valid in `cpt_`:
  - `cpt_.p`
  - `cpt_.activehypo`
  - `cpt_.factormetadata`
  - `ccwl.params`
  - `ccwl.measurement`
DevNotes
- TODO refactor relationship and common fields between (CCW, FMd, CPT, CalcFactor)
"""
function _buildCalcFactorLambdaSample(ccwl::CommonConvWrapper,
                                      smpid::Int,
                                      cpt_::ConvPerThread = ccwl.cpt[Threads.threadid()],
                                      target::AbstractVector = view(ccwl.params[ccwl.varidx], cpt_.p, smpid),
                                      measurement_ = ccwl.measurement,
                                      fmd_::FactorMetadata = cpt_.factormetadata )
  #
  hypo = cpt_.activehypo
  # decision-variable memory restricted to the active hypothesis selection
  hypoParams = view(ccwl.params, hypo)
  # factor metadata restricted to the same hypothesis selection
  # FIXME must refactor (memory waste)
  hypoFmd = FactorMetadata( view(fmd_.fullvariables, hypo),
                            view(fmd_.variablelist, hypo),
                            hypoParams,
                            fmd_.solvefor,
                            fmd_.cachedata )
  #
  # operational CalcFactor object (Mixture factors are unwrapped inside)
  calcfct = _buildCalcFactorMixture(ccwl, hypoFmd, smpid, measurement_, hypoParams)
  #
  # reset the residual vector
  fill!(cpt_.res, 0.0) # Roots->xDim | Minimize->zDim
  #
  # static lambda evaluating the factor on column `smpid` of every argument
  unrollHypo! = function ()
    measArgs = _viewdim1or2.(measurement_, :, smpid)
    paramArgs = view.(hypoParams, :, smpid)
    return calcfct(measArgs..., paramArgs...)
  end
  return unrollHypo!, target
end
"""
$(SIGNATURES)
Numerically solve one sample (particle) of the free variable of relative factor `ccwl`.
This is the penultimate step before the numerical solver moves actual estimates, which is
done through an internally created lambda (built by `_buildCalcFactorLambdaSample`).

Depending on the factor type, `_solveLambdaNumeric` either root-finds the residual
(`<:AbstractRelativeRoots`, via NLsolve) or minimizes its squared norm
(`<:AbstractRelativeMinimize`, via Optim).

Steps:
- validate the factor configuration with `_checkErrorCCWNumerics` (may throw),
- use the `ConvPerThread` entry of the current thread and its sample `cpt_.particleidx`,
- add a small random perturbation to the target (nonzero only for Roots factors),
- solve starting from `cpt_.X[cpt_.p, smpid]`,
- write the solution back into `cpt_.X[cpt_.p, smpid]`. If the solution contains `NaN`,
  an error is logged and nothing is written back.

Always returns `nothing`.

Notes
- ccw.X must be set to memory ref the param[varidx] being solved, at creation of ccw
- Assumes `cpt_.p` is already set to desired X decision variable dimensions and size.
- Assumes only `ccw.particleidx` will be solved for
- small random (off-manifold) perturbation used to prevent trivial solver cases, div by 0 etc.
- perturb is necessary for NLsolve cases, and smaller than 1e-10 will result in test failure
- Also incorporates the active hypo lookup
DevNotes
- TODO testshuffle is now obsolete, should be removed
- TODO perhaps consolidate perturbation with inflation or nullhypo
- NOTE(review): indexing per-thread state with `Threads.threadid()` assumes the running task
  does not migrate between threads -- confirm.
"""
function _solveCCWNumeric!( ccwl::Union{CommonConvWrapper{F},
                                        CommonConvWrapper{Mixture{N_,F,S,T}}};
                            perturb::Real=1e-10,
                            testshuffle::Bool=false ) where {N_,F<:AbstractRelative,S,T}
  #
  # FIXME, move this check higher and out of smpid loop
  _checkErrorCCWNumerics(ccwl, testshuffle)
  #
  thrid = Threads.threadid()
  cpt_ = ccwl.cpt[thrid]
  smpid = cpt_.particleidx
  # cannot Nelder-Mead on 1dim, partial can be 1dim or more but being conservative.
  islen1 = length(cpt_.p) == 1 || ccwl.partial
  # build the pre-objective function for this sample's hypothesis selection
  unrollHypo!, target = _buildCalcFactorLambdaSample(ccwl, smpid, cpt_)
  # broadcast updates original view memory location
  ## using CalcFactor legacy path inside (::CalcFactor)
  _hypoObj = (x) -> (target.=x; unrollHypo!())
  # TODO small off-manifold perturbation is a numerical workaround only, make on-manifold requires RoME.jl #244
  # use all element dimensions : ==> 1:ccwl.xDim
  target .+= _perturbIfNecessary(getFactorType(ccwl), length(target), perturb)
  # do the parameter search over defined decision variables
  retval = _solveLambdaNumeric(getFactorType(ccwl), _hypoObj, cpt_.res, cpt_.X[cpt_.p, smpid], islen1 )
  # reject solutions containing NaN (allocation-free check)
  if any(isnan, retval)
    @error "$(ccwl.usrfnc!), ccw.thrid_=$(thrid), got NaN, smpid = $(smpid), r=$(retval)\n"
    return nothing
  end
  # insert result back at the correct variable element location
  cpt_.X[cpt_.p,smpid] .= retval
  return nothing
end
# brainstorming:
# - should only build a new argument list, according to activehypo, at the start of each particle
# - try calling an existing lambda instead
# - this is sensitive to which hypothesis is active, see #1024
# - content inside .cpt.fmd as well as .params must be shuffled accordingly
#
#