Skip to content

Commit

Permalink
half revert fixes to partialDims
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed Apr 7, 2021
1 parent 61ab4f4 commit df87f46
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 36 deletions.
50 changes: 28 additions & 22 deletions src/ApproxConv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,40 @@ export findRelatedFromPotential
Internal method to set which dimensions should be used as the decision variables for later numerical optimization.
"""
function _setCCWDecisionDimsConv!(ccwl::Union{CommonConvWrapper{F},
CommonConvWrapper{Mixture{N_,F,S,T}}} ) where {N_,F<:Union{AbstractRelativeMinimize, AbstractPrior},S,T}
CommonConvWrapper{Mixture{N_,F,S,T}}} ) where {N_,F<:Union{AbstractRelativeMinimize, AbstractRelativeRoots, AbstractPrior},S,T}
#
# return nothing

p = if ccwl.partial
Int[ccwl.usrfnc!.partial...]
Int32[ccwl.usrfnc!.partial...]
else
Int[1:ccwl.xDim...]
Int32[1:ccwl.xDim...]
end

ccwl.partialDims = (p)
# NOTE should only be done in the constructor
# for thrid in 1:Threads.nthreads()
# length(ccwl.cpt[thrid].p) != length(p) ? resize!(ccwl.cpt[thrid].p, length(p)) : nothing
# ccwl.cpt[thrid].p .= p
# end
for thrid in 1:Threads.nthreads()
length(ccwl.cpt[thrid].p) != length(p) ? resize!(ccwl.cpt[thrid].p, length(p)) : nothing
ccwl.cpt[thrid].p .= p # SVector... , see ccw.partialDims
end
nothing
end

function _setCCWDecisionDimsConv!(ccwl::Union{CommonConvWrapper{F},
CommonConvWrapper{Mixture{N_,F,S,T}}} ) where {N_,F<:AbstractRelativeRoots,S,T}
#
return nothing

# # should be done with constructor only
# for thrid in 1:Threads.nthreads()
# length(ccwl.cpt[thrid].p) != ccwl.xDim ? resize!(ccwl.cpt[thrid].p, ccwl.xDim) : nothing
# ccwl.cpt[thrid].p .= Int[1:ccwl.xDim;]
# end
# nothing
end
# function _setCCWDecisionDimsConv!(ccwl::Union{CommonConvWrapper{F},
# CommonConvWrapper{Mixture{N_,F,S,T}}} ) where {N_,F<:AbstractRelativeRoots,S,T}
# #
# # return nothing

# p = Int[1:ccwl.xDim;]
# ccwl.partialDims = SVector(Int32.(p)...)

# # should be done with constructor only
# for thrid in 1:Threads.nthreads()
# # length(ccwl.cpt[thrid].p) != ccwl.xDim ? resize!(ccwl.cpt[thrid].p, ccwl.xDim) : nothing
# ccwl.cpt[thrid].p = p # SVector(Int32[1:ccwl.xDim;]...)
# end
# nothing
# end

"""
$(SIGNATURES)
Expand Down Expand Up @@ -833,12 +839,12 @@ function proposalbeliefs!(dfg::AbstractDFG,
data = getSolverData(fct)
p, inferd = findRelatedFromPotential(dfg, fct, destvertlabel, measurement, N=N, dbg=dbg, solveKey=solveKey)
if _getCCW(data).partial # partial density
pardims = _getCCW(data).usrfnc!.partial
pardims = _getDimensionsPartial(_getCCW(data)) # _getCCW(data).usrfnc!.partial
for dimnum in pardims
if haskey(partials, dimnum)
push!(partials[dimnum], marginal(p,[dimnum]))
push!(partials[dimnum], marginal(p,Int.([dimnum;])))
else
partials[dimnum] = BallTreeDensity[marginal(p,[dimnum])]
partials[dimnum] = BallTreeDensity[marginal(p,Int.([dimnum;]))]
end
end
else # add onto full density list
Expand Down
11 changes: 11 additions & 0 deletions src/FGOSUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ getFactorDim(w...) = getDimension(w...)
# getFactorDim(fc::DFGFactor) = getFactorDim(getSolverData(fc))
getFactorDim(fg::AbstractDFG, fctid::Symbol) = getFactorDim(getFactor(fg, fctid))



function _getDimensionsPartial(ccw::CommonConvWrapper)
# @warn "_getDimensionsPartial not ready for use yet"
ccw.partialDims
end
_getDimensionsPartial(data::GenericFunctionNodeData) = _getCCW(data) |> _getDimensionsPartial
_getDimensionsPartial(fct::DFGFactor) = _getDimensionsPartial(_getCCW(fct))
_getDimensionsPartial(fg::AbstractDFG, lbl::Symbol) = _getDimensionsPartial(getFactor(fg, lbl))


"""
$SIGNATURES
Get `.factormetadata` for each CPT in CCW for a specific factor in `fg`.
Expand Down
9 changes: 6 additions & 3 deletions src/FactorGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,19 +639,22 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
# zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
certainhypo = multihypo !== nothing ? collect(1:length(multihypo.p))[multihypo.p .== 0.0] : collect(1:length(Xi))

# sort out partialDims here
ispartl = hasfield(T, :partial)

ccw = CommonConvWrapper(
usrfnc,
zeros(1,0),
zdim,
ARR,
fmd,
specialzDim = sum(fldnms .== :zDim) >= 1,
partial = sum(fldnms .== :partial) >= 1,
specialzDim = hasfield(T, :zDim),
partial = ispartl,
hypotheses=multihypo,
certainhypo=certainhypo,
nullhypo=nullhypo,
threadmodel=threadmodel,
inflation=inflation
inflation=inflation,
)
#
return ccw
Expand Down
21 changes: 10 additions & 11 deletions src/FactorGraphTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ DevNotes
- TODO make static params {XDIM, ZDIM, P}
- TODO make immutable
"""
mutable struct ConvPerThread{R,F<:FactorMetadata, N}
mutable struct ConvPerThread{R,F<:FactorMetadata}
thrid_::Int
# the actual particle being solved at this moment
particleidx::Int
Expand All @@ -185,7 +185,7 @@ mutable struct ConvPerThread{R,F<:FactorMetadata, N}
# subsection indices to select which params should be used for this hypothesis evaluation
activehypo::Vector{Int}
# Select which decision variables to include in a particular optimization run
p::SVector{N, Int32} # Vector{Int}
p::Vector{Int}
# slight numerical perturbation for degenerate solver cases such as division by zero
perturb::Vector{Float64}
# working memory location for optimization routines on target decision variables
Expand All @@ -209,10 +209,10 @@ function ConvPerThread( X::Array{Float64,2},
particleidx,
factormetadata,
Int[activehypo;],
SVector(Int32.(p)...),
[p...;],
perturb,
X,
res)
res )
end


Expand All @@ -222,8 +222,7 @@ $(TYPEDEF)
"""
mutable struct CommonConvWrapper{ T<:FunctorInferenceType,
H<:Union{Nothing, Distributions.Categorical},
C<:Union{Nothing, Vector{Int}},
N } <: FactorOperationalMemory
C<:Union{Nothing, Vector{Int}} } <: FactorOperationalMemory
#
### Values consistent across all threads during approx convolution
usrfnc!::T # user factor / function
Expand All @@ -248,8 +247,8 @@ mutable struct CommonConvWrapper{ T<:FunctorInferenceType,
cpt::Vector{<:ConvPerThread}
# inflationSpread
inflation::Float64
# which dimensions does this factor influence
partialDims::SVector{N, Int32}
# DONT USE THIS YET which dimensions does this factor influence
partialDims::Vector{Int} # should become SVector{N, Int32}
end


Expand All @@ -268,7 +267,7 @@ function CommonConvWrapper( fnc::T,
measurement::Tuple=(zeros(0,1),),
particleidx::Int=1,
xDim::Int=size(X,1),
p=collect(1:xDim), # TODO make this SVector, and name partialDims
partialDims=collect(1:size(X,1)), # TODO make this SVector, and name partialDims
perturb=zeros(zDim),
res::AbstractVector{<:Real}=zeros(zDim),
threadmodel::Type{<:_AbstractThreadModel}=MultiThreaded,
Expand All @@ -287,10 +286,10 @@ function CommonConvWrapper( fnc::T,
measurement,
threadmodel,
(i->ConvPerThread(X, zDim,factormetadata, particleidx=particleidx,
activehypo=activehypo, p=p,
activehypo=activehypo, p=partialDims,
perturb=perturb, res=res )).(1:Threads.nthreads()),
inflation,
SVector(Int32.(p)...)
partialDims # SVector(Int32.()...)
)
end

Expand Down

0 comments on commit df87f46

Please sign in to comment.