Skip to content

Commit

Permalink
Merge pull request #1778 from JuliaRobotics/23Q3/dev/parametric
Browse files Browse the repository at this point in the history
Refactor Parametric solve for better performance and better use of Manopt.jl
  • Loading branch information
Affie authored Oct 11, 2023
2 parents 525602e + 6f3e431 commit d2963ac
Show file tree
Hide file tree
Showing 24 changed files with 943 additions and 594 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ version = "0.34.1"
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down Expand Up @@ -67,7 +66,6 @@ IncrInfrInterpolationsExt = "Interpolations"
[compat]
ApproxManifoldProducts = "0.7, 0.8"
BSON = "0.2, 0.3"
BlockArrays = "0.16"
Combinatorics = "1.0"
DataStructures = "0.16, 0.17, 0.18"
DelimitedFiles = "1"
Expand Down
14 changes: 11 additions & 3 deletions src/Factors/GenericFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end

#::MeasurementOnTangent
function distanceTangent2Point(M::SemidirectProductGroup, X, p, q)
= Manifolds.compose(M, p, exp(M, identity_element(M, p), X)) #for groups
= Manifolds.compose(M, p, exp(M, getPointIdentity(M), X)) #for groups
# return log(M, q, q̂)
return vee(M, q, log(M, q, q̂))
# return distance(M, q, q̂)
Expand Down Expand Up @@ -96,7 +96,7 @@ end

# function (cf::CalcFactor{<:ManifoldFactor{<:AbstractDecoratorManifold}})(Xc, p, q)
function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
return distanceTangent2Point(cf.manifold, X, p, q)
return distanceTangent2Point(cf.factor.M, X, p, q)
end

## ======================================================================================
Expand Down Expand Up @@ -141,12 +141,20 @@ function getSample(cf::CalcFactor{<:ManifoldPrior})
return point
end

function getFactorMeasurementParametric(fac::ManifoldPrior)
M = getManifold(fac)
dims = manifold_dimension(M)
meas = fac.p
= convert(SMatrix{dims, dims}, invcov(fac.Z))
meas, iΣ
end

#TODO investigate SVector if small dims, this is slower
# dim = manifold_dimension(M)
# Xc = [SVector{dim}(rand(Z)) for _ in 1:N]

function (cf::CalcFactor{<:ManifoldPrior})(m, p)
M = cf.manifold # .factor.M
M = cf.factor.M
# return log(M, p, m)
return vee(M, p, log(M, p, m))
# return distancePrior(M, m, p)
Expand Down
13 changes: 11 additions & 2 deletions src/Factors/Mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
## example case is old FluxModelsPose2Pose2 requiring velocity
# FIXME better consolidation of when to pass down .mechanics, also see #1099 and #1094 and #1069

cf_ = CalcFactor(
cf_ = CalcFactorNormSq(
cf.factor.mechanics,
0,
cf._legacyParams,
Expand All @@ -133,10 +133,19 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
#out memory should be right size first
length(cf.factor.labels) != N ? resize!(cf.factor.labels, N) : nothing
cf.factor.labels .= rand(cf.factor.diversity, N)
M = cf.manifold

# mixture needs to be refactored so let's make it worse :-)
if cf.factor.mechanics isa AbstractPrior
samplef = samplePoint
elseif cf.factor.mechanics isa AbstractRelative
samplef = sampleTangent
end

for i = 1:N
mixComponent = cf.factor.components[cf.factor.labels[i]]
# measurements relate to the factor's manifold (either tangent vector or manifold point)
setPointsMani!(smpls[i], rand(mixComponent, 1))
setPointsMani!(smpls, samplef(M, mixComponent), i)
end

# TODO only does first element of meas::Tuple at this stage, see #1099
Expand Down
1 change: 1 addition & 0 deletions src/Factors/MsgPrior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ end
getManifold(mp::MsgPrior{<:ManifoldKernelDensity}) = mp.Z.manifold
getManifold(mp::MsgPrior) = mp.M

#FIXME this will not work on manifolds
(cfo::CalcFactor{<:MsgPrior})(z, x1) = z .- x1

Base.@kwdef struct PackedMsgPrior <: AbstractPackedFactor
Expand Down
2 changes: 1 addition & 1 deletion src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ include("parametric/services/ConsolidateParametricRelatives.jl")
include("parametric/services/ParametricCSMFunctions.jl")
include("parametric/services/ParametricUtils.jl")
include("parametric/services/ParametricOptim.jl")
include("parametric/services/ParametricManoptDev.jl")
include("parametric/services/ParametricManopt.jl")
include("services/MaxMixture.jl")

#X-stroke
Expand Down
8 changes: 8 additions & 0 deletions src/entities/AliasScalarSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ struct AliasingScalarSampler
end
end

function sampleTangent(
M::AbstractDecoratorManifold, # stand-in type to restrict to just group manifolds
z::AliasingScalarSampler,
p = getPointIdentity(M),
)
return hat(M, p, SVector{manifold_dimension(M)}(rand(z)))
end

function rand!(ass::AliasingScalarSampler, smpls::Array{Float64})
StatsBase.alias_sample!(ass.domain, ass.weights, smpls)
return nothing
Expand Down
61 changes: 44 additions & 17 deletions src/entities/CalcFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
abstract type AbstractMaxMixtureSolver end


abstract type CalcFactor{T<:AbstractFactor} end


"""
$TYPEDEF
Expand All @@ -21,18 +24,19 @@ end
DevNotes
- Follow the Github project in IIF to better consolidate CCW FMD CPT CF CFM
- TODO CalcFactorNormSq is a step towards having a dedicated structure for non-parametric solve.
CalcFactorNormSq will calculate the Norm Squared of the factor.
Related
[`CalcFactorMahalanobis`](@ref), [`CommonConvWrapper`](@ref)
"""
struct CalcFactor{
struct CalcFactorNormSq{
FT <: AbstractFactor,
X,
C,
VT <: Tuple,
M <: AbstractManifold
}
} <: CalcFactor{FT}
""" the interface compliant user object functor containing the data and logic """
factor::FT
""" what is the sample (particle) id for which the residual is being calculated """
Expand All @@ -54,7 +58,15 @@ struct CalcFactor{
manifold::M
end


#TODO deprecate after CalcFactor is updated to CalcFactorNormSq
function CalcFactor(args...; kwargs...)
Base.depwarn(
"`CalcFactor` changed to an abstract type, use CalcFactorNormSq, CalcFactorMahalanobis, or CalcFactorResidual",
:CalcFactor
)

CalcFactorNormSq(args...; kwargs...)
end

"""
$TYPEDEF
Expand All @@ -65,32 +77,47 @@ Related
[`CalcFactor`](@ref)
"""
struct CalcFactorMahalanobis{N, D, L, S <: Union{Nothing, AbstractMaxMixtureSolver}}
struct CalcFactorMahalanobis{
FT,
N,
C,
MEAS<:AbstractArray,
D,
L,
S <: Union{Nothing, AbstractMaxMixtureSolver}
} <: CalcFactor{FT}
faclbl::Symbol
calcfactor!::CalcFactor
factor::FT
cache::C
varOrder::Vector{Symbol}
meas::NTuple{N, <:AbstractArray}
meas::NTuple{N, MEAS}
::NTuple{N, SMatrix{D, D, Float64, L}}
specialAlg::S
end




struct CalcFactorManopt{
struct CalcFactorResidual{
FT <: AbstractFactor,
C,
D,
L,
FT <: AbstractFactor,
M <: AbstractManifold,
P,
MEAS <: AbstractArray,
}
N
} <: CalcFactor{FT}
faclbl::Symbol
calcfactor!::CalcFactor{FT, Nothing, Nothing, Tuple{}, M}
varOrder::Vector{Symbol}
varOrderIdxs::Vector{Int}
factor::FT
cache::C
varOrder::NTuple{N, Symbol}
varOrderIdxs::NTuple{N, Int}
points::P #TODO remove or not?
meas::MEAS
::SMatrix{D, D, Float64, L}
::SMatrix{D, D, Float64, L} #TODO remove or not?
sqrt_iΣ::SMatrix{D, D, Float64, L}
end

_nvars(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = N
# _typeof_meas(::CalcFactorManopt{FT, C, D, L, MEAS, N}) where {FT, C, D, L, MEAS, N} = MEAS
DFG.getDimension(::CalcFactorResidual{FT, C, D, L, P, MEAS, N}) where {FT, C, D, L, P, MEAS, N} = D


6 changes: 4 additions & 2 deletions src/manifolds/services/ManifoldSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ end
function sampleTangent(
M::AbstractDecoratorManifold,
z::Distribution,
p = identity_element(M), #getPointIdentity(M),
p = getPointIdentity(M),
)
return hat(M, p, rand(z, 1)[:]) #TODO find something better than (z,1)[:]
return hat(M, p, SVector{length(z)}(rand(z))) #TODO make sure all Distribution has length,
# if this errors maybe fall back no next line
# return convert(typeof(p), hat(M, p, rand(z, 1)[:])) #TODO find something better than (z,1)[:]
end

"""
Expand Down
18 changes: 9 additions & 9 deletions src/manifolds/services/ManifoldsExtentions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,54 +98,54 @@ end

import DistributedFactorGraphs: getPointIdentity

function getPointIdentity(G::ProductGroup, ::Type{T} = Float64) where {T <: Real}
function DFG.getPointIdentity(G::ProductGroup, ::Type{T} = Float64) where {T <: Real}
M = G.manifold
return ArrayPartition(map(x -> getPointIdentity(x, T), M.manifolds))
end

# fallback
function getPointIdentity(G::GroupManifold, ::Type{T} = Float64) where {T <: Real}
function DFG.getPointIdentity(G::GroupManifold, ::Type{T} = Float64) where {T <: Real}
return error("getPointIdentity not implemented on $G")
end

function getPointIdentity(
function DFG.getPointIdentity(
@nospecialize(G::ProductManifold),
::Type{T} = Float64,
) where {T <: Real}
return ArrayPartition(map(x -> getPointIdentity(x, T), G.manifolds))
end

function getPointIdentity(
function DFG.getPointIdentity(
@nospecialize(M::PowerManifold),
::Type{T} = Float64,
) where {T <: Real}
N = Manifolds.get_iterator(M).stop
return fill(getPointIdentity(M.manifold, T), N)
end

function getPointIdentity(M::NPowerManifold, ::Type{T} = Float64) where {T <: Real}
function DFG.getPointIdentity(M::NPowerManifold, ::Type{T} = Float64) where {T <: Real}
return fill(getPointIdentity(M.manifold, T), M.N)
end

function getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where {T <: Real}
function DFG.getPointIdentity(G::SemidirectProductGroup, ::Type{T} = Float64) where {T <: Real}
M = base_manifold(G)
N, H = M.manifolds
np = getPointIdentity(N, T)
hp = getPointIdentity(H, T)
return ArrayPartition(np, hp)
end

function getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
function DFG.getPointIdentity(G::SpecialOrthogonal{N}, ::Type{T} = Float64) where {N, T <: Real}
return SMatrix{N, N, T}(I)
end

function getPointIdentity(
function DFG.getPointIdentity(
G::TranslationGroup{Tuple{N}},
::Type{T} = Float64,
) where {N, T <: Real}
return zeros(SVector{N,T})
end

function getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
function DFG.getPointIdentity(G::RealCircleGroup, ::Type{T} = Float64) where {T <: Real}
return [zero(T)] #FIXME we cannot support scalars yet
end
Loading

0 comments on commit d2963ac

Please sign in to comment.