Skip to content

Commit

Permalink
Merge pull request #956 from JuliaRobotics/fix/4Q20/933
Browse files Browse the repository at this point in the history
standardize Mixture, fix #933
  • Loading branch information
dehann authored Oct 8, 2020
2 parents 0e665cf + dbfb167 commit b846e53
Show file tree
Hide file tree
Showing 14 changed files with 291 additions and 256 deletions.
2 changes: 1 addition & 1 deletion examples/BayesTreeIllustration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ addFactor!(fg, [:x0], Prior(Normal(0,1)))
addVariable!(fg, :x1, ContinuousScalar)
addFactor!(fg, [:x0, :x1], LinearRelative(Normal(10.0,1)))
addVariable!(fg, :x2, ContinuousScalar)
mmo = MixtureRelative(LinearRelative, [Rayleigh(3); Uniform(30,55)], Categorical([0.4; 0.6]))
mmo = Mixture(LinearRelative, [Rayleigh(3); Uniform(30,55)], Categorical([0.4; 0.6]))
addFactor!(fg, [:x1, :x2], mmo)


Expand Down
2 changes: 1 addition & 1 deletion examples/IllustrateAutoInit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ plotKDE(fg, [:x0, :x1])
# add another node, but introduce more general beliefs
addVariable!(fg, :x2, ContinuousScalar)

mmo = MixtureRelative(LinearRelative, [Rayleigh(3); Uniform(30,55)], Categorical([0.4; 0.6]))
mmo = Mixture(LinearRelative, [Rayleigh(3); Uniform(30,55)], Categorical([0.4; 0.6]))
addFactor!(fg, [:x1, :x2], mmo)

# Graphs.plot(fg.g)
Expand Down
99 changes: 78 additions & 21 deletions src/ApproxConv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function approxConvOnElements!( ccwl::CommonConvWrapper{T},
# ccwl.thrid_ = Threads.threadid()
ccwl.cpt[Threads.threadid()].particleidx = n
# ccall(:jl_, Nothing, (Any,), "starting loop, thrid_=$(Threads.threadid()), partidx=$(ccwl.cpt[Threads.threadid()].particleidx)")
numericRootGenericRandomizedFnc!( ccwl )
numericSolutionCCW!( ccwl )
end
nothing
end
Expand All @@ -36,12 +36,12 @@ Future work:
- improve handling of n and particleidx, especially considering future multithreading support
"""
function approxConvOnElements!( ccwl::CommonConvWrapper{T},
elements::Union{Vector{Int}, UnitRange{Int}}, ::Type{SingleThreaded}) where {T <: AbstractRelative}
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
elements::Union{Vector{Int}, UnitRange{Int}}, ::Type{SingleThreaded}) where {N_,F<:AbstractRelative,S,T}
#
for n in elements
ccwl.cpt[Threads.threadid()].particleidx = n
numericRootGenericRandomizedFnc!( ccwl )
numericSolutionCCW!( ccwl ) # numericRootGenericRandomizedFnc!
end
nothing
end
Expand All @@ -60,8 +60,8 @@ Future work:
- improve handling of n and particleidx, especially considering future multithreading support
"""
function approxConvOnElements!( ccwl::CommonConvWrapper{T},
elements::Union{Vector{Int}, UnitRange{Int}} ) where {T <: AbstractRelative}
function approxConvOnElements!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}}, #CommonConvWrapper{T},
elements::Union{Vector{Int}, UnitRange{Int}} ) where {N_,F<:AbstractRelative,S,T}
#
approxConvOnElements!(ccwl, elements, ccwl.threadmodel)
end
Expand All @@ -73,11 +73,12 @@ end
Prepare a common functor computation object `prepareCommonConvWrapper{T}` containing the user factor functor along with additional variables and information using during approximate convolution computations.
"""
function prepareCommonConvWrapper!( ccwl::CommonConvWrapper{T},
function prepareCommonConvWrapper!( F_::Type{<:AbstractRelative},
ccwl::CommonConvWrapper{F},
Xi::Vector{DFGVariable},
solvefor::Symbol,
N::Int;
solveKey::Symbol=:default ) where {T <: AbstractRelative}
solveKey::Symbol=:default ) where {F <: FunctorInferenceType}
#
ARR = Array{Array{Float64,2},1}()
# FIXME maxlen should parrot N (barring multi-/nullhypo issues)
Expand Down Expand Up @@ -112,6 +113,16 @@ function prepareCommonConvWrapper!( ccwl::CommonConvWrapper{T},
return sfidx, maxlen, manis
end


function prepareCommonConvWrapper!( ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
Xi::Vector{DFGVariable},
solvefor::Symbol,
N::Int;
solveKey::Symbol=:default ) where {N_,F<:AbstractRelative,S,T}
#
prepareCommonConvWrapper!(F, ccwl, Xi, solvefor, N, solveKey=solveKey)
end

function generateNullhypoEntropy( val::AbstractMatrix{<:Real},
maxlen::Int,
spreadfactor::Real=10 )
Expand Down Expand Up @@ -193,14 +204,15 @@ end
Common function to compute across a single user defined multi-hypothesis ambiguity per factor. This function dispatches both `AbstractRelativeFactor` and `AbstractRelativeFactorMinimize` factors.
"""
function computeAcrossHypothesis!(ccwl::CommonConvWrapper{T},
function computeAcrossHypothesis!(ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
allelements,
activehypo,
certainidx::Vector{Int},
sfidx::Int,
maxlen::Int,
maniAddOps::Tuple;
spreadNH::Float64=3.0 ) where {T <:AbstractRelative}
spreadNH::Real=3.0 ) where {N_,F<:AbstractRelative,S,T}
#
count = 0
# TODO remove assert once all GenericWrapParam has been removed
# @assert norm(ccwl.certainhypo - certainidx) < 1e-6
Expand All @@ -220,7 +232,7 @@ function computeAcrossHypothesis!(ccwl::CommonConvWrapper{T},
# multihypo, take other value case
# sfidx=2, hypoidx=3: 2 should take a value from 3
# sfidx=3, hypoidx=2: 3 should take a value from 2
# DEBUG sfidx=2, hypoidx=1 -- bad when do something like multihypo=[0.5;0.5] -- issue 424
# DEBUG sfidx=2, hypoidx=1 -- bad when do something like multihypo=[0.5;0.5] -- issue 424
ccwl.params[sfidx][:,allelements[count]] = view(ccwl.params[hypoidx],:,allelements[count])
elseif hypoidx == 0
# basically do nothing since the factor is not active for these allelements[count]
Expand All @@ -245,6 +257,7 @@ function computeAcrossHypothesis!(ccwl::CommonConvWrapper{T},
end



"""
$(SIGNATURES)
Expand All @@ -255,11 +268,12 @@ Planned changes will fold null hypothesis in as a standard feature and no longer
function evalPotentialSpecific( Xi::Vector{DFGVariable},
ccwl::CommonConvWrapper{T},
solvefor::Symbol,
T_::Type{<:AbstractRelative},
measurement::Tuple=(zeros(0,100),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
spreadNH::Real=3.0,
dbg::Bool=false ) where {T <: AbstractRelative}
dbg::Bool=false ) where {T <: FunctorInferenceType}
#

# Prep computation variables
Expand Down Expand Up @@ -290,11 +304,12 @@ end
function evalPotentialSpecific( Xi::Vector{DFGVariable},
ccwl::CommonConvWrapper{T},
solvefor::Symbol,
T_::Type{<:AbstractPrior},
measurement::Tuple=(zeros(0,0),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
dbg::Bool=false,
spreadNH::Float64=3.0 ) where {T <: AbstractPrior}
spreadNH::Real=3.0 ) where {T <: FunctorInferenceType}
#
# FIXME, NEEDS TO BE CLEANED UP AND WORK ON MANIFOLDS PROPER
fnc = ccwl.usrfnc!
Expand Down Expand Up @@ -347,18 +362,60 @@ function evalPotentialSpecific( Xi::Vector{DFGVariable},
return addEntr
end


function evalPotentialSpecific( Xi::Vector{DFGVariable},
ccwl::CommonConvWrapper{Mixture{N_,F,S,T}},
solvefor::Symbol,
measurement::Tuple=(zeros(0,0),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
dbg::Bool=false,
spreadNH::Real=3.0 ) where {N_,F<:FunctorInferenceType,S,T}
#
evalPotentialSpecific(Xi,
ccwl,
solvefor,
F,
measurement;
solveKey=solveKey,
N=N,
dbg=dbg,
spreadNH=spreadNH )
end


function evalPotentialSpecific( Xi::Vector{DFGVariable},
ccwl::CommonConvWrapper{F},
solvefor::Symbol,
measurement::Tuple=(zeros(0,0),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
dbg::Bool=false,
spreadNH::Real=3.0 ) where {F <: FunctorInferenceType}
#
evalPotentialSpecific(Xi,
ccwl,
solvefor,
F,
measurement;
solveKey=solveKey,
N=N,
dbg=dbg,
spreadNH=spreadNH )
end

"""
$(SIGNATURES)
Single entry point for evaluating factors from factor graph, using multiple dispatch to locate the correct `evalPotentialSpecific` function.
"""
function evalFactor2(dfg::AbstractDFG,
fct::DFGFactor,
solvefor::Symbol,
measurement::Tuple=(zeros(0,100),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
dbg::Bool=false )
function evalFactor2( dfg::AbstractDFG,
fct::DFGFactor,
solvefor::Symbol,
measurement::Tuple=(zeros(0,100),);
solveKey::Symbol=:default,
N::Int=size(measurement[1],2),
dbg::Bool=false )
#

ccw = getSolverData(fct).fnc
Expand Down Expand Up @@ -436,7 +493,7 @@ function approxConvBinary(arr::Array{Float64,2},

for n in 1:N
ccw.cpt[Threads.threadid()].particleidx = n
numericRootGenericRandomizedFnc!( ccw )
numericSolutionCCW!( ccw )
end
return pts
end
Expand Down
72 changes: 65 additions & 7 deletions src/Deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function findRelatedFromPotential(dfg::AbstractDFG,
dbg::Bool=false;
solveKey::Symbol=:default )
#
@warn("findRelatedFromPotential is obsolete, use `productbelief(fg, variableSym, :)`", maxlog=1)
@warn("findRelatedFromPotential likely to be deprecated, use `lsf` or `productbelief(fg, variableSym, ...) instead`", maxlog=1)

# assuming it is properly initialized TODO
ptsbw = evalFactor2(dfg, fct, varid, solveKey=solveKey, N=N, dbg=dbg);
Expand Down Expand Up @@ -77,6 +77,67 @@ end



# function prepareCommonConvWrapper!( ccwl::CommonConvWrapper{F},
# Xi::Vector{DFGVariable},
# solvefor::Symbol,
# N::Int;
# solveKey::Symbol=:default ) where {F <: AbstractRelative}
# #
# prepareCommonConvWrapper!(F, ccwl, Xi, solvefor, N, solveKey=solveKey)
# end

# function computeAcrossHypothesis!(ccwl::Union{CommonConvWrapper{F},CommonConvWrapper{Mixture{N_,F,S,T}}},
# allelements,
# activehypo,
# certainidx::Vector{Int},
# sfidx::Int,
# maxlen::Int,
# maniAddOps::Tuple;
# spreadNH::Real=3.0 ) where {N_,F<:AbstractRelative,S,T}
# #
# computeAcrossHypothesis!(F,ccwl,allelements,activehypo,certainidx,sfidx, maxlen,maniAddOps,spreadNH=spreadNH)
# end


# function computeAcrossHypothesis!(ccwl::CommonConvWrapper{Mixture{N_,F,S,T}},
# allelements,
# activehypo,
# certainidx::Vector{Int},
# sfidx::Int,
# maxlen::Int,
# maniAddOps::Tuple;
# spreadNH::Real=3.0 ) where {N_,F<:AbstractRelative,S,T}
# #
# computeAcrossHypothesis!(F,ccwl,allelements,activehypo,certainidx,sfidx, maxlen,maniAddOps,spreadNH=spreadNH)
# end


@deprecate numericRootGenericRandomizedFnc!(w...;kw...) numericSolutionCCW!(w...;kw...)


# function numericRootGenericRandomizedFnc!(ccwl::CommonConvWrapper{Mixture{N,F,S,T}};
# perturb::Float64=1e-10,
# testshuffle::Bool=false ) where
# {N,F<:AbstractRelative,S,T <: Tuple}
# #
# _numericSolutionCCW!(F, ccwl,perturb=perturb, testshuffle=testshuffle)
# end


# function numericRootGenericRandomizedFnc!(ccwl::CommonConvWrapper{F};
# perturb::Float64=1e-10,
# testshuffle::Bool=false ) where
# {F <: AbstractRelative}
# #
# _numericSolutionCCW!(F, ccwl, perturb=perturb, testshuffle=testshuffle)
# end


@deprecate MixtureRelative(w...; kw...) Mixture(w...; kw...)

@deprecate MixturePrior(w...; kw...) Mixture(Prior, w...; kw...)


# function areSiblingsRemaingNeedDownOnly(tree::AbstractBayesTree,
# cliq::TreeClique )::Bool
# #
Expand Down Expand Up @@ -564,12 +625,9 @@ end

@deprecate getMsgDwnChannel(tree::AbstractBayesTree, edge) getDwnMsgConsolidated(tree, edge)

export MixtureLinearConditional
# export MixtureLinearConditional

function MixtureLinearConditional(Z::AbstractVector{T}, C::DiscreteNonParametric) where T <: SamplableBelief
@warn("MixtureLinearConditional is deprecated, use `MixtureRelative(LinearConditional(LinearAlgebra.I), Z, C)` instead.")
MixtureRelative(LinearConditional(LinearAlgebra.I), Z, C)
end
@deprecate MixtureLinearConditional(Z::AbstractVector{<:SamplableBelief}, C::DiscreteNonParametric) Mixture(LinearRelative, Z, C)


"""
Expand Down Expand Up @@ -809,7 +867,7 @@ end

# getSample(s::MixtureRelative, N::Int=1) = (reshape.(rand.(s.Z, N),1,:)..., rand(s.C, N))

@deprecate (MixturePrior{T}(z::NTuple{N,<:SamplableBelief}, c::Union{<:Distributions.DiscreteNonParametric, NTuple{N,<:Real}, <:AbstractVector{<:Real}} ) where {T,N}) MixturePrior(z,c)
# @deprecate (MixturePrior{T}(z::NTuple{N,<:SamplableBelief}, c::Union{<:Distributions.DiscreteNonParametric, NTuple{N,<:Real}, <:AbstractVector{<:Real}} ) where {T,N}) MixturePrior(z,c)

@deprecate LinearConditional(N::Int=1) LinearRelative{N}(LinearAlgebra.I)
# @deprecate LinearConditional(x::SamplableBelief) LinearRelative(x)
Expand Down
2 changes: 2 additions & 0 deletions src/Factors/DefaultPrior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ not recommended when non-Euclidean dimensions are used in variables.
struct Prior{T <: SamplableBelief} <: AbstractPrior
Z::T
end
Prior(::UniformScaling) = Prior(Normal())

getSample(s::Prior, N::Int=1) = (reshape(rand(s.Z,N),:,N), )


Expand Down
Loading

0 comments on commit b846e53

Please sign in to comment.