diff --git a/NEWS.md b/NEWS.md index 764d10b9..b4c1f585 100644 --- a/NEWS.md +++ b/NEWS.md @@ -23,6 +23,7 @@ The list below highlights breaking changes according to normal semver workflow - - Further bug fixes for transition to `StaticArrays` value stores and computes, including `Position{N}` (#1779, #1776). - Restore `DifferentialEquation.jl` factor `DERelative` functionality and tests that were suppressed in a previous upgrade (#1774, #1777). - Restore previously suppressed tests (#1781, #1721, #1780) +- Improve DERelative factor on-manifold operations (#1775, #1802, #1803). # Changes in v0.34 diff --git a/Project.toml b/Project.toml index 1e150c45..8d5b52fb 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "IncrementalInference" uuid = "904591bb-b899-562f-9e6f-b8df64c7d480" keywords = ["MM-iSAMv2", "Bayes tree", "junction tree", "Bayes network", "variable elimination", "graphical models", "SLAM", "inference", "sum-product", "belief-propagation"] desc = "Implements the Multimodal-iSAMv2 algorithm." -version = "0.35.0" +version = "0.35.1" [deps] ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89" @@ -91,6 +91,7 @@ RecursiveArrayTools = "2.31.1" Reexport = "1" SparseDiffTools = "2" StaticArrays = "1" +Statistics = "1" StatsBase = "0.32, 0.33, 0.34" StructTypes = "1" TensorCast = "0.3.3, 0.4" diff --git a/ext/IncrInfrDiffEqFactorExt.jl b/ext/IncrInfrDiffEqFactorExt.jl index d9be747d..69d64a84 100644 --- a/ext/IncrInfrDiffEqFactorExt.jl +++ b/ext/IncrInfrDiffEqFactorExt.jl @@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt @info "IncrementalInference.jl is loading extensions related to DifferentialEquations.jl" +import Base: show + using DifferentialEquations import DifferentialEquations: solve @@ -15,10 +17,30 @@ using DocStringExtensions export DERelative +import Manifolds: allocate, compose, hat, Identity, vee, log getManifold(de::DERelative{T}) where {T} = getManifold(de.domain) + +function Base.show( + io::IO, + ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}} +) where {T,O} + println(io, " DERelative{") + println(io, " ", T) + println(io, " ", O.name.name) + println(io, " }") + nothing +end + +Base.show( + io::IO, + ::MIME"text/plain", + der::DERelative +) = show(io, der) + + """ $SIGNATURES @@ -28,7 +50,9 @@ DevNotes - TODO does not yet incorporate Xi.nanosecond field. - TODO does not handle timezone crossing properly yet. """ -function _calcTimespan(Xi::AbstractVector{<:DFGVariable}) +function _calcTimespan( + Xi::AbstractVector{<:DFGVariable} +) # tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix # toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3) @@ -47,10 +71,10 @@ function DERelative( f::Function, data = () -> (); dt::Real = 1, - state0::AbstractVector{<:Real} = zeros(getDimension(domain)), - state1::AbstractVector{<:Real} = zeros(getDimension(domain)), + state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), + state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)), tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi), - problemType = DiscreteProblem, + problemType = ODEProblem, # DiscreteProblem, ) # datatuple = if 2 < length(Xi) @@ -60,11 +84,11 @@ function DERelative( data end # forward time problem - fproblem = problemType(f, state0, tspan, datatuple; dt = dt) + fproblem = problemType(f, state0, tspan, datatuple; dt) # backward time problem bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt) # build the IIF recognizable object - return DERelative(domain, fproblem, bproblem, datatuple, getSample) + return DERelative(domain, fproblem, bproblem, datatuple) #, getSample) end function DERelative( @@ -75,8 +99,8 @@ function DERelative( data = () -> (); Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels), dt::Real = 1, - state0::AbstractVector{<:Real} = zeros(getDimension(domain)), - state1::AbstractVector{<:Real} = zeros(getDimension(domain)), + state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)), + state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)), tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi), problemType = DiscreteProblem, ) @@ -85,26 +109,32 @@ function DERelative( domain, f, data; - dt = dt, - state0 = state0, - state1 = state1, - tspan = tspan, - problemType = problemType, + dt, + state0, + state1, + tspan, + problemType, ) end # # # n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...) -function _solveFactorODE!(measArr, prob, u0pts, Xtra...) +function _solveFactorODE!( + measArr, + prob, + u0pts, + Xtra... +) # happens when more variables (n-ary) must be included in DE solve for (xid, xtra) in enumerate(Xtra) # update the data register before ODE solver calls the function - prob.p[xid + 1][:] = xtra[:] + prob.p[xid + 1][:] = xtra[:] # FIXME, unlikely to work with ArrayPartition, maybe use MArray and `.=` end # set the initial condition - prob.u0[:] = u0pts[:] + prob.u0 .= u0pts + sol = DifferentialEquations.solve(prob) # extract solution from solved ode @@ -155,21 +185,21 @@ end # NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE -function (cf::CalcFactor{<:DERelative})(measurement, X...) +function (cf::CalcFactor{<:DERelative})( + measurement, + X... +) # + # numerical measurement values meas1 = measurement[1] - diffOp = measurement[2] - + # work on-manifold via sampleFactor piggy back of particular manifold definition + M = measurement[2] + # lazy factor pointer oderel = cf.factor - - # work on-manifold - # diffOp = meas[2] - # if backwardSolve else forward - # check direction - solveforIdx = cf.solvefor - + + # if backwardSolve else forward if solveforIdx > 2 # need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable) solveforIdx = 2 @@ -185,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...) end # find the difference between measured and predicted. - ## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`) - ## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon! - # @show cf._sampleIdx, solveforIdx, meas1 - - #FIXME - res = zeros(size(X[2], 1)) - for i = 1:size(X[2], 1) - # diffop( reference?, test? ) <===> ΔX = test \ reference - res[i] = diffOp[i](X[solveforIdx][i], meas1[i]) - end + # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`) + res_ = compose(M, inv(M, X[solveforIdx]), meas1) + res = vee(M, Identity(M), log(M, Identity(M), res_)) + return res end @@ -249,28 +273,32 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int oder = cf.factor # how many trajectories to propagate? - # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2]) - meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N] + # + v2T = getVariableType(cf.fullvariables[2]) + meas = [allocate(getPointIdentity(v2T)) for _ = 1:N] + # meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N] # pick forward or backward direction # set boundary condition - u0pts = if cf.solvefor == 1 + u0pts, M = if cf.solvefor == 1 # backward direction prob = oder.backwardProblem + M_ = getManifold(getVariableType(cf.fullvariables[1])) addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( - convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))), + convert(Tuple, M_), ) # getBelief(cf.fullvariables[2]) |> getPoints - cf._legacyParams[2] + cf._legacyParams[2], M_ else # forward backward prob = oder.forwardProblem + M_ = getManifold(getVariableType(cf.fullvariables[2])) # buffer manifold operations for use during factor evaluation addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks( - convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))), + convert(Tuple, M_), ) # getBelief(cf.fullvariables[1]) |> getPoints - cf._legacyParams[1] + cf._legacyParams[1], M_ end # solve likely elements @@ -281,17 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...) end - return map(x -> (x, diffOp), meas) + # return meas, M + return map(x -> (x, M), meas) end # getDimension(oderel.domain) - - -## the function -# ode.problem.f.f - -# - end # module \ No newline at end of file diff --git a/src/ExportAPI.jl b/src/ExportAPI.jl index 429d2e17..a415dd4d 100644 --- a/src/ExportAPI.jl +++ b/src/ExportAPI.jl @@ -1,5 +1,14 @@ # the IncrementalInference API + +# reexport +export ℝ, AbstractManifold +export Identity, hat , vee, ArrayPartition, exp!, exp, log!, log +# common groups -- preferred defaults at this time. +export TranslationGroup, RealCircleGroup +# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26. +export Euclidean, Circle + # DFG SpecialDefinitions export AbstractDFG, getSolverParams, diff --git a/src/IncrementalInference.jl b/src/IncrementalInference.jl index 65fff3a0..22243543 100644 --- a/src/IncrementalInference.jl +++ b/src/IncrementalInference.jl @@ -19,13 +19,6 @@ using FiniteDifferences using OrderedCollections: OrderedDict -export ℝ, AbstractManifold -# export ProductRepr -# common groups -- preferred defaults at this time. -export TranslationGroup, RealCircleGroup -# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26. -export Euclidean, Circle - import Optim using Dates, diff --git a/src/entities/ExtFactors.jl b/src/entities/ExtFactors.jl index 36dce47e..f6d916e4 100644 --- a/src/entities/ExtFactors.jl +++ b/src/entities/ExtFactors.jl @@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab backwardProblem::P """ second element of this data tuple is additional variables that will be passed down as a parameter """ data::D - specialSampler::Function + # specialSampler::Function end \ No newline at end of file