From e1d3af610ee7a78f9371a9d96f2d88ed2dd31945 Mon Sep 17 00:00:00 2001 From: dehann Date: Wed, 8 Nov 2023 05:44:19 -0800 Subject: [PATCH 1/3] DERelative residual more on-manifold --- ext/IncrInfrDiffEqFactorExt.jl | 47 +++++++++++++++++++--------------- src/entities/ExtFactors.jl | 2 +- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/ext/IncrInfrDiffEqFactorExt.jl b/ext/IncrInfrDiffEqFactorExt.jl index af5b19cf..b996baff 100644 --- a/ext/IncrInfrDiffEqFactorExt.jl +++ b/ext/IncrInfrDiffEqFactorExt.jl @@ -17,7 +17,7 @@ using DocStringExtensions export DERelative -import Manifolds: allocate +import Manifolds: allocate, compose, hat, Identity, vee, log getManifold(de::DERelative{T}) where {T} = getManifold(de.domain) @@ -67,7 +67,7 @@ function DERelative( # backward time problem bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt) # build the IIF recognizable object - return DERelative(domain, fproblem, bproblem, datatuple, getSample) + return DERelative(domain, fproblem, bproblem, datatuple) #, getSample) end function DERelative( @@ -88,11 +88,11 @@ function DERelative( domain, f, data; - dt = dt, - state0 = state0, - state1 = state1, - tspan = tspan, - problemType = problemType, + dt, + state0, + state1, + tspan, + problemType, ) end # @@ -162,7 +162,8 @@ end function (cf::CalcFactor{<:DERelative})(measurement, X...) # meas1 = measurement[1] - diffOp = measurement[2] + M = measurement[2] + # diffOp = measurement[2] oderel = cf.factor @@ -193,12 +194,15 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...) ## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon! # @show cf._sampleIdx, solveforIdx, meas1 - #FIXME - res = zeros(size(X[2], 1)) - for i = 1:size(X[2], 1) - # diffop( reference?, test? ) <===> ΔX = test \ reference - res[i] = diffOp[i](X[solveforIdx][i], meas1[i]) - end + res_ = compose(M, inv(M, X[solveforIdx]), meas1) + res = vee(M, Identity(M), log(M, Identity(M), res_)) + + # #FIXME 0 + # res = zeros(size(X[2], 1)) + # for i = 1:size(X[2], 1) + # # diffop( reference?, test? ) <===> ΔX = test \ reference + # res[i] = diffOp[i](X[solveforIdx][i], meas1[i]) + # end return res end @@ -260,23 +264,25 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int # pick forward or backward direction # set boundary condition - u0pts = if cf.solvefor == 1 + u0pts, M = if cf.solvefor == 1 # backward direction prob = oder.backwardProblem + M_ = getManifold(getVariableType(cf.fullvariables[1])) addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( - convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))), + convert(Tuple, M_), ) # getBelief(cf.fullvariables[2]) |> getPoints - cf._legacyParams[2] + cf._legacyParams[2], M_ else # forward backward prob = oder.forwardProblem + M_ = getManifold(getVariableType(cf.fullvariables[2])) # buffer manifold operations for use during factor evaluation addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( - convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))), + convert(Tuple, M_), ) # getBelief(cf.fullvariables[1]) |> getPoints - cf._legacyParams[1] + cf._legacyParams[1], M_ end # solve likely elements @@ -287,7 +293,8 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...) end - return map(x -> (x, diffOp), meas) + # return meas, M + return map(x -> (x, M), meas) end # getDimension(oderel.domain) diff --git a/src/entities/ExtFactors.jl b/src/entities/ExtFactors.jl index 36dce47e..f6d916e4 100644 --- a/src/entities/ExtFactors.jl +++ b/src/entities/ExtFactors.jl @@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab backwardProblem::P """ second element of this data tuple is additional variables that will be passed down as a parameter """ data::D - specialSampler::Function + # specialSampler::Function end \ No newline at end of file From 3d91fbb011378c949b606ab5b031a57535bbb2fd Mon Sep 17 00:00:00 2001 From: dehann Date: Wed, 8 Nov 2023 08:30:07 -0800 Subject: [PATCH 2/3] further fixes cleanup DERelative for imu --- ext/IncrInfrDiffEqFactorExt.jl | 83 +++++++++++++++++----------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/ext/IncrInfrDiffEqFactorExt.jl b/ext/IncrInfrDiffEqFactorExt.jl index b996baff..60ac711d 100644 --- a/ext/IncrInfrDiffEqFactorExt.jl +++ b/ext/IncrInfrDiffEqFactorExt.jl @@ -22,6 +22,25 @@ import Manifolds: allocate, compose, hat, Identity, vee, log getManifold(de::DERelative{T}) where {T} = getManifold(de.domain) + +function Base.show( + io::IO, + ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}} +) where {T,O} + println(io, " DERelative{") + println(io, " ", T) + println(io, " ", O.name.name) + println(io, " }") + nothing +end + +Base.show( + io::IO, + ::MIME"text/plain", + der::DERelative +) = show(io, der) + + """ $SIGNATURES @@ -31,7 +50,9 @@ DevNotes - TODO does not yet incorporate Xi.nanosecond field. - TODO does not handle timezone crossing properly yet. """ -function _calcTimespan(Xi::AbstractVector{<:DFGVariable}) +function _calcTimespan( + Xi::AbstractVector{<:DFGVariable} +) # tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix # toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3) @@ -50,8 +71,8 @@ function DERelative( f::Function, data = () -> (); dt::Real = 1, - state0::AbstractVector{<:Real} = zeros(getDimension(domain)), - state1::AbstractVector{<:Real} = zeros(getDimension(domain)), + state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), + state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi), problemType = DiscreteProblem, ) @@ -78,8 +99,8 @@ function DERelative( data = () -> (); Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels), dt::Real = 1, - state0::AbstractVector{<:Real} = zeros(getDimension(domain)), - state1::AbstractVector{<:Real} = zeros(getDimension(domain)), + state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)), + state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)), tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi), problemType = DiscreteProblem, ) @@ -99,7 +120,12 @@ end # # n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...) -function _solveFactorODE!(measArr, prob, u0pts, Xtra...) +function _solveFactorODE!( + measArr, + prob, + u0pts, + Xtra... +) # happens when more variables (n-ary) must be included in DE solve for (xid, xtra) in enumerate(Xtra) # update the data register before ODE solver calls the function @@ -159,22 +185,21 @@ end # NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE -function (cf::CalcFactor{<:DERelative})(measurement, X...) +function (cf::CalcFactor{<:DERelative})( + measurement, + X... +) # + # numerical measurement values meas1 = measurement[1] + # work on-manifold via sampleFactor piggy back of particular manifold definition M = measurement[2] - # diffOp = measurement[2] - + # lazy factor pointer oderel = cf.factor - - # work on-manifold - # diffOp = meas[2] - # if backwardSolve else forward - # check direction - solveforIdx = cf.solvefor - + + # if backwardSolve else forward if solveforIdx > 2 # need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable) solveforIdx = 2 @@ -190,19 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...) end # find the difference between measured and predicted. - ## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`) - ## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon! - # @show cf._sampleIdx, solveforIdx, meas1 - + # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`) res_ = compose(M, inv(M, X[solveforIdx]), meas1) res = vee(M, Identity(M), log(M, Identity(M), res_)) - # #FIXME 0 - # res = zeros(size(X[2], 1)) - # for i = 1:size(X[2], 1) - # # diffop( reference?, test? ) <===> ΔX = test \ reference - # res[i] = diffOp[i](X[solveforIdx][i], meas1[i]) - # end return res end @@ -300,19 +316,4 @@ end -function Base.show(io::IO, ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}) where {T,O} - println(io, " DERelative{") - println(io, " ", T) - println(io, " ", O.name.name) - println(io, " }") - nothing -end - -Base.show(io::IO, ::MIME"text/plain", der::DERelative) = show(io, der) - -## the function -# ode.problem.f.f - -# - end # module \ No newline at end of file From cfed191e8dd3f5f4b88efe1ebdf7ec34a4ac81c4 Mon Sep 17 00:00:00 2001 From: dehann Date: Wed, 8 Nov 2023 20:25:15 -0800 Subject: [PATCH 3/3] cleanup --- ext/IncrInfrDiffEqFactorExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/IncrInfrDiffEqFactorExt.jl b/ext/IncrInfrDiffEqFactorExt.jl index 60ac711d..69d64a84 100644 --- a/ext/IncrInfrDiffEqFactorExt.jl +++ b/ext/IncrInfrDiffEqFactorExt.jl @@ -74,7 +74,7 @@ function DERelative( state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi), - problemType = DiscreteProblem, + problemType = ODEProblem, # DiscreteProblem, ) # datatuple = if 2 < length(Xi) @@ -84,7 +84,7 @@ function DERelative( data end # forward time problem - fproblem = problemType(f, state0, tspan, datatuple; dt = dt) + fproblem = problemType(f, state0, tspan, datatuple; dt) # backward time problem bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt) # build the IIF recognizable object