Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup more organized on DERelative #1803

Merged
merged 3 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 64 additions & 56 deletions ext/IncrInfrDiffEqFactorExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,30 @@ using DocStringExtensions

export DERelative

import Manifolds: allocate
import Manifolds: allocate, compose, hat, Identity, vee, log


getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)


function Base.show(
io::IO,
::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
) where {T,O}
println(io, " DERelative{")
println(io, " ", T)
println(io, " ", O.name.name)
println(io, " }")
nothing
end

Base.show(
io::IO,
::MIME"text/plain",
der::DERelative
) = show(io, der)


"""
$SIGNATURES

Expand All @@ -31,7 +50,9 @@ DevNotes
- TODO does not yet incorporate Xi.nanosecond field.
- TODO does not handle timezone crossing properly yet.
"""
function _calcTimespan(Xi::AbstractVector{<:DFGVariable})
function _calcTimespan(
Xi::AbstractVector{<:DFGVariable}
)
#
tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix
# toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
Expand All @@ -50,10 +71,10 @@ function DERelative(
f::Function,
data = () -> ();
dt::Real = 1,
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
problemType = DiscreteProblem,
problemType = ODEProblem, # DiscreteProblem,
)
#
datatuple = if 2 < length(Xi)
Expand All @@ -63,11 +84,11 @@ function DERelative(
data
end
# forward time problem
fproblem = problemType(f, state0, tspan, datatuple; dt = dt)
fproblem = problemType(f, state0, tspan, datatuple; dt)
# backward time problem
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
# build the IIF recognizable object
return DERelative(domain, fproblem, bproblem, datatuple, getSample)
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
end

function DERelative(
Expand All @@ -78,8 +99,8 @@ function DERelative(
data = () -> ();
Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels),
dt::Real = 1,
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
problemType = DiscreteProblem,
)
Expand All @@ -88,18 +109,23 @@ function DERelative(
domain,
f,
data;
dt = dt,
state0 = state0,
state1 = state1,
tspan = tspan,
problemType = problemType,
dt,
state0,
state1,
tspan,
problemType,
)
end
#
#

# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
function _solveFactorODE!(
measArr,
prob,
u0pts,
Xtra...
)
# happens when more variables (n-ary) must be included in DE solve
for (xid, xtra) in enumerate(Xtra)
# update the data register before ODE solver calls the function
Expand Down Expand Up @@ -159,21 +185,21 @@ end


# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
function (cf::CalcFactor{<:DERelative})(measurement, X...)
function (cf::CalcFactor{<:DERelative})(
measurement,
X...
)
#
# numerical measurement values
meas1 = measurement[1]
diffOp = measurement[2]

# work on-manifold via sampleFactor piggy back of particular manifold definition
M = measurement[2]
# lazy factor pointer
oderel = cf.factor

# work on-manifold
# diffOp = meas[2]
# if backwardSolve else forward

# check direction

solveforIdx = cf.solvefor


# if backwardSolve else forward
if solveforIdx > 2
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
solveforIdx = 2
Expand All @@ -189,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
end

# find the difference between measured and predicted.
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
# @show cf._sampleIdx, solveforIdx, meas1

#FIXME
res = zeros(size(X[2], 1))
for i = 1:size(X[2], 1)
# diffop( reference?, test? ) <===> ΔX = test \ reference
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
end
# assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
res_ = compose(M, inv(M, X[solveforIdx]), meas1)
res = vee(M, Identity(M), log(M, Identity(M), res_))

return res
end

Expand Down Expand Up @@ -260,23 +280,25 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int

# pick forward or backward direction
# set boundary condition
u0pts = if cf.solvefor == 1
u0pts, M = if cf.solvefor == 1
# backward direction
prob = oder.backwardProblem
M_ = getManifold(getVariableType(cf.fullvariables[1]))
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
convert(Tuple, M_),
)
# getBelief(cf.fullvariables[2]) |> getPoints
cf._legacyParams[2]
cf._legacyParams[2], M_
else
# forward backward
prob = oder.forwardProblem
M_ = getManifold(getVariableType(cf.fullvariables[2]))
# buffer manifold operations for use during factor evaluation
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
convert(Tuple, M_),
)
# getBelief(cf.fullvariables[1]) |> getPoints
cf._legacyParams[1]
cf._legacyParams[1], M_
end

# solve likely elements
Expand All @@ -287,25 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
end

return map(x -> (x, diffOp), meas)
# return meas, M
return map(x -> (x, M), meas)
end
# getDimension(oderel.domain)



function Base.show(io::IO, ::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}) where {T,O}
println(io, " DERelative{")
println(io, " ", T)
println(io, " ", O.name.name)
println(io, " }")
nothing
end

Base.show(io::IO, ::MIME"text/plain", der::DERelative) = show(io, der)

## the function
# ode.problem.f.f

#

end # module
2 changes: 1 addition & 1 deletion src/entities/ExtFactors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab
backwardProblem::P
""" second element of this data tuple is additional variables that will be passed down as a parameter """
data::D
specialSampler::Function
# specialSampler::Function
end
Loading