Skip to content

Commit

Permalink
test partial prior construction and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed Apr 10, 2021
1 parent df87f46 commit 94070a8
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 67 deletions.
65 changes: 44 additions & 21 deletions src/ApproxConv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function prepareCommonConvWrapper!( F_::Type{<:AbstractRelative},
solvefor::Symbol,
N::Int;
needFreshMeasurements::Bool=true,
solveKey::Symbol=:default ) where {F <: FunctorInferenceType}
solveKey::Symbol=:default ) where {F <: AbstractFactor}
#

# FIXME, order of fmd ccwl cf are a little weird and should be revised.
Expand Down Expand Up @@ -273,10 +273,6 @@ function computeAcrossHypothesis!(ccwl::Union{<:CommonConvWrapper{F},
count = 0

cpt_ = ccwl.cpt[Threads.threadid()]

# setup the partial or complete decision variable dimensions for this ccwl object
# NOTE perhaps deconv has changed the decision variable list, so placed here during consolidation phase
_setCCWDecisionDimsConv!(ccwl)

# @assert norm(ccwl.certainhypo - certainidx) < 1e-6
for (hypoidx, vars) in activehypo
Expand Down Expand Up @@ -363,7 +359,11 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
if 0 < size(measurement[1],1)
ccwl.measurement = measurement
end


# setup the partial or complete decision variable dimensions for this ccwl object
# NOTE perhaps deconv has changed the decision variable list, so placed here during consolidation phase
_setCCWDecisionDimsConv!(ccwl)

# Check which variables have been initialized
isinit = map(x->isInitialized(x), Xi)

Expand All @@ -375,7 +375,7 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
# TODO convert to HypothesisRecipeElements result
_, allelements, activehypo, mhidx = assembleHypothesesElements!(ccwl.hypotheses, maxlen, sfidx, length(Xi), isinit, ccwl.nullhypo )
certainidx = ccwl.certainhypo

# perform the numeric solutions on the indicated elements
# error("ccwl.xDim=$(ccwl.xDim)")
# FIXME consider repeat solve as workaround for inflation off-zero
Expand All @@ -400,6 +400,10 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
inflateCycles::Int=3,
skipSolve::Bool=false ) where {T <: AbstractFactor}
#
# setup the partial or complete decision variable dimensions for this ccwl object
# NOTE perhaps deconv has changed the decision variable list, so placed here during consolidation phase
_setCCWDecisionDimsConv!(ccwl)

# FIXME, NEEDS TO BE CLEANED UP AND WORK ON MANIFOLDS PROPER
fnc = ccwl.usrfnc!
sfidx = 1
Expand Down Expand Up @@ -464,7 +468,7 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
dbg::Bool=false,
spreadNH::Real=3.0,
inflateCycles::Int=3,
skipSolve::Bool=false ) where {N_,F<:FunctorInferenceType,S,T}
skipSolve::Bool=false ) where {N_,F<:AbstractFactor,S,T}
#
evalPotentialSpecific(Xi,
ccwl,
Expand All @@ -491,7 +495,7 @@ function evalPotentialSpecific( Xi::AbstractVector{<:DFGVariable},
dbg::Bool=false,
spreadNH::Real=3.0,
inflateCycles::Int=3,
skipSolve::Bool=false ) where {F <: FunctorInferenceType}
skipSolve::Bool=false ) where {F <: AbstractFactor}
#
evalPotentialSpecific(Xi,
ccwl,
Expand Down Expand Up @@ -689,7 +693,7 @@ end
# TODO should this be consolidated with regular approxConv?
# TODO, perhaps pass Xi::Vector{DFGVariable} instead?
function approxConvBinary(arr::Array{Float64,2},
meas::FunctorInferenceType,
meas::AbstractFactor,
outdims::Int,
fmd::FactorMetadata,
measurement::Tuple=(zeros(0,size(arr,2)),);
Expand Down Expand Up @@ -814,6 +818,10 @@ function findRelatedFromPotential(dfg::AbstractDFG,
end


function _expandNamedTupleType()

end


"""
$SIGNATURES
Expand All @@ -828,27 +836,42 @@ function proposalbeliefs!(dfg::AbstractDFG,
destvertlabel::Symbol,
factors::AbstractVector{<:DFGFactor},
dens::Vector{BallTreeDensity},
partials::Dict{Int, Vector{BallTreeDensity}},
partials::Dict{Any, Vector{BallTreeDensity}}, # TODO change this structure
measurement::Tuple=(zeros(0,0),);
solveKey::Symbol=:default,
N::Int=100,
dbg::Bool=false )
#


# group partial dimension factors by selected dimensions -- i.e. [(1,)], [(1,2),(1,2)], [(2,);(2;)]


# populate the full and partial dim containers
inferddimproposal = Vector{Float64}(undef, length(factors))
for (count,fct) in enumerate(factors)
data = getSolverData(fct)
p, inferd = findRelatedFromPotential(dfg, fct, destvertlabel, measurement, N=N, dbg=dbg, solveKey=solveKey)
if _getCCW(data).partial # partial density
pardims = _getDimensionsPartial(_getCCW(data)) # _getCCW(data).usrfnc!.partial
for dimnum in pardims
if haskey(partials, dimnum)
push!(partials[dimnum], marginal(p,Int.([dimnum;])))
else
partials[dimnum] = BallTreeDensity[marginal(p,Int.([dimnum;]))]
end
ccwl = _getCCW(data)
propBel, inferd = findRelatedFromPotential(dfg, fct, destvertlabel, measurement, N=N, dbg=dbg, solveKey=solveKey)
if isPartial(ccwl) # partial density # ccwl.partial
pardims = _getDimensionsPartial(ccwl) # _getCCW(data).usrfnc!.partial
@assert [getFactorType(fct).partial...] == [pardims...] "partial dims error $(getFactorType(fct).partial) vs $pardims"

marg_ = marginal(propBel, Int[pardims...])
if haskey(partials, pardims)
push!(partials[pardims], marg_)
else
partials[pardims] = BallTreeDensity[marg_;]
end
# for dimnum in pardims
# if haskey(partials, dimnum)
# push!(partials[dimnum], marginal(propBel, Int.([dimnum;])))
# else
# partials[dimnum] = BallTreeDensity[marginal(propBel, Int.([dimnum;]))]
# end
# end
else # add onto full density list
push!(dens, p)
push!(dens, propBel)
end
inferddimproposal[count] = inferd
end
Expand Down
4 changes: 2 additions & 2 deletions src/CalcFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ New generation user factor interface method for computing the residual values of
Notes
- Under development and still experimental. Expected to become default method in IIF v0.20.0
"""
struct CalcFactor{T <: FunctorInferenceType, M, P <: Tuple, X <: AbstractVector}
struct CalcFactor{T <: AbstractFactor, M, P <: Tuple, X <: AbstractVector}
# the interface compliant user object functor containing the data and logic
factor::T
# the metadata to be passed to the user residual function
Expand Down Expand Up @@ -49,7 +49,7 @@ DevNotes
- Use in place operations where possible and remember `measurement` is a `::Tuple`.
- TODO only works on `.threadid()==1` at present, see #1094
"""
function sampleFactor(cf::CalcFactor{<:FunctorInferenceType},
function sampleFactor(cf::CalcFactor{<:AbstractFactor},
N::Int=1 )
#
getSample(cf, N)
Expand Down
4 changes: 2 additions & 2 deletions src/DispatchPackedConversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Base.convert(::Type{<:SamplableBelief}, mkd::PackedManifoldKernelDensity) = conv



function packmultihypo(fnc::CommonConvWrapper{T}) where {T<:FunctorInferenceType}
function packmultihypo(fnc::CommonConvWrapper{T}) where {T<:AbstractFactor}
@warn "packmultihypo is deprecated in favor of Vector only operations"
fnc.hypotheses !== nothing ? string(fnc.hypotheses) : ""
end
Expand Down Expand Up @@ -42,7 +42,7 @@ end

function convert(
::Type{GenericFunctionNodeData{CommonConvWrapper{F}}},
packed::GenericFunctionNodeData{P} ) where {F <: FunctorInferenceType, P <: PackedInferenceType}
packed::GenericFunctionNodeData{P} ) where {F <: AbstractFactor, P <: PackedInferenceType}
#
# TODO store threadmodel=MutliThreaded,SingleThreaded in persistence layer
usrfnc = convert(F, packed.fnc)
Expand Down
6 changes: 6 additions & 0 deletions src/FactorGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,11 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},

# sort out partialDims here
ispartl = hasfield(T, :partial)
partialDims = if ispartl
Int[usrfnc.partial...]
else
Int[]
end

ccw = CommonConvWrapper(
usrfnc,
Expand All @@ -655,6 +660,7 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
nullhypo=nullhypo,
threadmodel=threadmodel,
inflation=inflation,
partialDims=partialDims
)
#
return ccw
Expand Down
6 changes: 3 additions & 3 deletions src/FactorGraphTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ function ConvPerThread( X::Array{Float64,2},
factormetadata::FactorMetadata;
particleidx::Int=1,
activehypo= 1:length(params),
p=collect(1:size(X,1)),
p::AbstractVector{<:Integer}=collect(1:size(X,1)),
perturb=zeros(zDim),
res=zeros(zDim),
thrid_ = 0 )
Expand All @@ -209,7 +209,7 @@ function ConvPerThread( X::Array{Float64,2},
particleidx,
factormetadata,
Int[activehypo;],
[p...;],
Int[p...;],
perturb,
X,
res )
Expand Down Expand Up @@ -267,7 +267,7 @@ function CommonConvWrapper( fnc::T,
measurement::Tuple=(zeros(0,1),),
particleidx::Int=1,
xDim::Int=size(X,1),
partialDims=collect(1:size(X,1)), # TODO make this SVector, and name partialDims
partialDims::AbstractVector{<:Integer}=collect(1:size(X,1)), # TODO make this SVector, and name partialDims
perturb=zeros(zDim),
res::AbstractVector{<:Real}=zeros(zDim),
threadmodel::Type{<:_AbstractThreadModel}=MultiThreaded,
Expand Down
77 changes: 43 additions & 34 deletions src/GraphProductOperations.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@


function _partialProducts!(pGM, partials, manis)
# whats up with this?
for (dimnum,pp) in partials
push!(pp, AMP.manikde!(pGM[dimnum:dimnum,:], (manis[dimnum],) ))
function _partialProducts!(pGM, partials, manis; useExisting::Bool=false)
if useExisting
# include previous calcs
for (dimnum,pp) in partials
dimv = [dimnum...]
push!(pp, AMP.manikde!(pGM[dimv,:], (manis[dimv]...,) ))
end
end

# do each partial dimension individually
for (dimnum,pp) in partials
pGM[dimnum,:] = AMP.manifoldProduct(pp, (manis[dimnum],), Niter=1) |> getPoints
dimv = [dimnum...]
pGM[dimv,:] = AMP.manifoldProduct(pp, (manis[dimv]...,), Niter=1) |> getPoints
end
end

"""
$(SIGNATURES)
# """
# $(SIGNATURES)

Multiply various full and partial dimension proposal densities.
# Multiply various full and partial dimension proposal densities.

DevNotes
- FIXME consolidate partial and full product AMP API, relates to #1010
- TODO better consolidate with full dimension product
- TODO -- reuse memory rather than rand here
"""
function prodmultiplefullpartials(dens::Vector{BallTreeDensity},
partials::Dict{Int, Vector{BallTreeDensity}},
Ndims::Int,
N::Int,
manis::Tuple )
#
# calculate products over all dimensions, legacy proposals held in `dens` vector
pGM = AMP.manifoldProduct(dens, manis, Niter=1) |> getPoints
# DevNotes
# - FIXME consolidate partial and full product AMP API, relates to #1010
# - TODO better consolidate with full dimension product
# - TODO -- reuse memory rather than rand here
# """
# function prodmultiplefullpartials(dens::Vector{BallTreeDensity},
# partials::Dict{Any, Vector{BallTreeDensity}},
# Ndims::Int,
# N::Int,
# manis::Tuple;
# useExisting::Bool=false )
# #
# # calculate products over all dimensions, legacy proposals held in `dens` vector
# pGM = AMP.manifoldProduct(dens, manis, Niter=1) |> getPoints

_partialProducts!(pGM, partials, manis)
# _partialProducts!(pGM, partials, manis, useExisting=useExisting)

return pGM
end
# return pGM
# end



Expand All @@ -52,7 +57,7 @@ Notes
function productbelief( dfg::AbstractDFG,
vertlabel::Symbol,
dens::Vector{<:BallTreeDensity},
partials::Dict{Int, <:AbstractVector{<:BallTreeDensity}},
partials::Dict{Any, <:AbstractVector{<:BallTreeDensity}},
N::Int;
dbg::Bool=false,
logger=ConsoleLogger() )
Expand All @@ -77,15 +82,19 @@ function productbelief( dfg::AbstractDFG,
# end

if 0 < lennonp # || (lennonp == 0 && 0 < lenpart)
# calculate products over all dimensions, legacy proposals held in `dens` vector
pGM = AMP.manifoldProduct(dens, manis, Niter=1) |> getPoints
# multiple non-partials
pGM = prodmultiplefullpartials(dens_, partials, Ndims, N, manis)
_partialProducts!(pGM, partials, manis, useExisting=true)
# pGM = prodmultiplefullpartials(dens_, partials, Ndims, N, manis, useExisting=true)
elseif lennonp == 0 && 0 < lenpart
# only partials, must get other existing values for vertlabel from dfg
pGM = deepcopy(denspts)
# do each partial dimension individually
for (dimnum,pp) in partials
pGM[dimnum,:] = AMP.manifoldProduct(pp, (manis[dimnum],), Niter=1) |> getPoints
end
_partialProducts!(pGM, partials, manis; useExisting=false)
# # do each partial dimension individually
# for (dimnum,pp) in partials
# pGM[dimnum,:] = AMP.manifoldProduct(pp, (manis[dimnum],), Niter=1) |> getPoints
# end
else
with_logger(logger) do
@warn "Unknown density product on variable=$(vert.label), lennonp=$(lennonp), lenpart=$(lenpart)"
Expand Down Expand Up @@ -114,7 +123,7 @@ function predictbelief( dfg::AbstractDFG,
dbg::Bool=false,
logger=ConsoleLogger(),
dens = Array{BallTreeDensity,1}(),
partials = Dict{Int, Vector{BallTreeDensity}}() )
partials = Dict{Any, Vector{BallTreeDensity}}() )
#

# determine number of particles to draw from the marginal
Expand All @@ -139,7 +148,7 @@ function predictbelief( dfg::AbstractDFG,
dbg::Bool=false,
logger=ConsoleLogger(),
dens = Array{BallTreeDensity,1}(),
partials = Dict{Int, Vector{BallTreeDensity}}() )
partials = Dict{Any, Vector{BallTreeDensity}}() )
#
factors = getFactor.(dfg, factorsyms)
vert = getVariable(dfg, destvertsym)
Expand All @@ -160,7 +169,7 @@ function predictbelief( dfg::AbstractDFG,
dbg::Bool=false,
logger=ConsoleLogger(),
dens = Array{BallTreeDensity,1}(),
partials = Dict{Int, Vector{BallTreeDensity}}() )
partials = Dict{Any, Vector{BallTreeDensity}}() )
#
predictbelief(dfg, destvertsym, getNeighbors(dfg, destvertsym), solveKey=solveKey, needFreshMeasurements=needFreshMeasurements, N=N, dbg=dbg, logger=logger, dens=dens, partials=partials )
end
Expand All @@ -185,7 +194,7 @@ function localProduct(dfg::AbstractDFG,

# # get proposal beliefs
dens = Array{BallTreeDensity,1}()
partials = Dict{Int, Vector{BallTreeDensity}}()
partials = Dict{Any, Vector{BallTreeDensity}}()
pGM, sinfd = predictbelief(dfg, sym, lb, solveKey=solveKey, logger=logger, dens=dens, partials=partials)

# make manifold belief from product
Expand Down
8 changes: 3 additions & 5 deletions src/JunctionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -973,11 +973,9 @@ end
Return `::Bool` on whether factor is a partial constraint.
"""
isPartial(fcf::T) where {T <: FunctorInferenceType} = :partial in fieldnames(T)
function isPartial(fct::DFGFactor) #fct::TreeClique
fcf = _getCCW(fct).usrfnc!
isPartial(fcf)
end
isPartial(fcf::T) where {T <: AbstractFactor} = :partial in fieldnames(T)
isPartial(ccw::CommonConvWrapper) = ccw.usrfnc! |> isPartial
isPartial(fct::DFGFactor) = _getCCW(fct) |> isPartial

"""
$SIGNATURES
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ include("testMixturePrior.jl")

include("testPartialFactors.jl")

include("testPartialPrior.jl")

include("testSaveLoadDFG.jl")

include("testJunctionTreeConstruction.jl")
Expand Down
Loading

0 comments on commit 94070a8

Please sign in to comment.