Skip to content

Commit

Permalink
Merge pull request #1804 from JuliaRobotics/master
Browse files Browse the repository at this point in the history
v0.35.1-rc1
  • Loading branch information
dehann authored Nov 14, 2023
2 parents 8ef0459 + dbdeab3 commit ed9ead0
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 60 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The list below highlights breaking changes according to normal semver workflow -
- Further bug fixes for transition to `StaticArrays` value stores and computes, including `Position{N}` (#1779, #1776).
- Restore `DifferentialEquation.jl` factor `DERelative` functionality and tests that were suppressed in a previous upgrade (#1774, #1777).
- Restore previously suppressed tests (#1781, #1721, #1780)
- Improve DERelative factor on-manifold operations (#1775, #1802, #1803).

# Changes in v0.34

Expand Down
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "IncrementalInference"
uuid = "904591bb-b899-562f-9e6f-b8df64c7d480"
keywords = ["MM-iSAMv2", "Bayes tree", "junction tree", "Bayes network", "variable elimination", "graphical models", "SLAM", "inference", "sum-product", "belief-propagation"]
desc = "Implements the Multimodal-iSAMv2 algorithm."
version = "0.35.0"
version = "0.35.1"

[deps]
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
Expand Down Expand Up @@ -91,6 +91,7 @@ RecursiveArrayTools = "2.31.1"
Reexport = "1"
SparseDiffTools = "2"
StaticArrays = "1"
Statistics = "1"
StatsBase = "0.32, 0.33, 0.34"
StructTypes = "1"
TensorCast = "0.3.3, 0.4"
Expand Down
124 changes: 73 additions & 51 deletions ext/IncrInfrDiffEqFactorExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt

@info "IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"

import Base: show

using DifferentialEquations
import DifferentialEquations: solve

Expand All @@ -15,10 +17,30 @@ using DocStringExtensions

export DERelative

import Manifolds: allocate, compose, hat, Identity, vee, log


getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)


function Base.show(
io::IO,
::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
) where {T,O}
println(io, " DERelative{")
println(io, " ", T)
println(io, " ", O.name.name)
println(io, " }")
nothing
end

Base.show(
io::IO,
::MIME"text/plain",
der::DERelative
) = show(io, der)


"""
$SIGNATURES
Expand All @@ -28,7 +50,9 @@ DevNotes
- TODO does not yet incorporate Xi.nanosecond field.
- TODO does not handle timezone crossing properly yet.
"""
function _calcTimespan(Xi::AbstractVector{<:DFGVariable})
function _calcTimespan(
Xi::AbstractVector{<:DFGVariable}
)
#
tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix
# toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
Expand All @@ -47,10 +71,10 @@ function DERelative(
f::Function,
data = () -> ();
dt::Real = 1,
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
problemType = DiscreteProblem,
problemType = ODEProblem, # DiscreteProblem,
)
#
datatuple = if 2 < length(Xi)
Expand All @@ -60,11 +84,11 @@ function DERelative(
data
end
# forward time problem
fproblem = problemType(f, state0, tspan, datatuple; dt = dt)
fproblem = problemType(f, state0, tspan, datatuple; dt)
# backward time problem
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
# build the IIF recognizable object
return DERelative(domain, fproblem, bproblem, datatuple, getSample)
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
end

function DERelative(
Expand All @@ -75,8 +99,8 @@ function DERelative(
data = () -> ();
Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels),
dt::Real = 1,
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
problemType = DiscreteProblem,
)
Expand All @@ -85,26 +109,32 @@ function DERelative(
domain,
f,
data;
dt = dt,
state0 = state0,
state1 = state1,
tspan = tspan,
problemType = problemType,
dt,
state0,
state1,
tspan,
problemType,
)
end
#
#

# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
function _solveFactorODE!(
measArr,
prob,
u0pts,
Xtra...
)
# happens when more variables (n-ary) must be included in DE solve
for (xid, xtra) in enumerate(Xtra)
# update the data register before ODE solver calls the function
prob.p[xid + 1][:] = xtra[:]
prob.p[xid + 1][:] = xtra[:] # FIXME, unlikely to work with ArrayPartition, maybe use MArray and `.=`
end

# set the initial condition
prob.u0[:] = u0pts[:]
prob.u0 .= u0pts

sol = DifferentialEquations.solve(prob)

# extract solution from solved ode
Expand Down Expand Up @@ -155,21 +185,21 @@ end


# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
function (cf::CalcFactor{<:DERelative})(measurement, X...)
function (cf::CalcFactor{<:DERelative})(
measurement,
X...
)
#
# numerical measurement values
meas1 = measurement[1]
diffOp = measurement[2]

# work on-manifold via sampleFactor piggy back of particular manifold definition
M = measurement[2]
# lazy factor pointer
oderel = cf.factor

# work on-manifold
# diffOp = meas[2]
# if backwardSolve else forward

# check direction

solveforIdx = cf.solvefor


# if backwardSolve else forward
if solveforIdx > 2
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
solveforIdx = 2
Expand All @@ -185,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
end

# find the difference between measured and predicted.
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
# @show cf._sampleIdx, solveforIdx, meas1

#FIXME
res = zeros(size(X[2], 1))
for i = 1:size(X[2], 1)
# diffop( reference?, test? ) <===> ΔX = test \ reference
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
end
# assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
res_ = compose(M, inv(M, X[solveforIdx]), meas1)
res = vee(M, Identity(M), log(M, Identity(M), res_))

return res
end

Expand Down Expand Up @@ -249,28 +273,32 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
oder = cf.factor

# how many trajectories to propagate?
# @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
#
v2T = getVariableType(cf.fullvariables[2])
meas = [allocate(getPointIdentity(v2T)) for _ = 1:N]
# meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]

# pick forward or backward direction
# set boundary condition
u0pts = if cf.solvefor == 1
u0pts, M = if cf.solvefor == 1
# backward direction
prob = oder.backwardProblem
M_ = getManifold(getVariableType(cf.fullvariables[1]))
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
convert(Tuple, M_),
)
# getBelief(cf.fullvariables[2]) |> getPoints
cf._legacyParams[2]
cf._legacyParams[2], M_
else
# forward backward
prob = oder.forwardProblem
M_ = getManifold(getVariableType(cf.fullvariables[2]))
# buffer manifold operations for use during factor evaluation
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
convert(Tuple, M_),
)
# getBelief(cf.fullvariables[1]) |> getPoints
cf._legacyParams[1]
cf._legacyParams[1], M_
end

# solve likely elements
Expand All @@ -281,17 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
end

return map(x -> (x, diffOp), meas)
# return meas, M
return map(x -> (x, M), meas)
end
# getDimension(oderel.domain)





## the function
# ode.problem.f.f

#

end # module
9 changes: 9 additions & 0 deletions src/ExportAPI.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# the IncrementalInference API


# reexport
export ℝ, AbstractManifold
export Identity, hat , vee, ArrayPartition, exp!, exp, log!, log
# common groups -- preferred defaults at this time.
export TranslationGroup, RealCircleGroup
# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26.
export Euclidean, Circle

# DFG SpecialDefinitions
export AbstractDFG,
getSolverParams,
Expand Down
7 changes: 0 additions & 7 deletions src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ using FiniteDifferences

using OrderedCollections: OrderedDict

export ℝ, AbstractManifold
# export ProductRepr
# common groups -- preferred defaults at this time.
export TranslationGroup, RealCircleGroup
# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26.
export Euclidean, Circle

import Optim

using Dates,
Expand Down
2 changes: 1 addition & 1 deletion src/entities/ExtFactors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab
backwardProblem::P
""" second element of this data tuple is additional variables that will be passed down as a parameter """
data::D
specialSampler::Function
# specialSampler::Function
end

0 comments on commit ed9ead0

Please sign in to comment.