Skip to content

Commit

Permalink
Merge pull request #1288 from JuliaRobotics/21Q3/enh/vecpproper
Browse files Browse the repository at this point in the history
wip, Vector{P} proper
  • Loading branch information
dehann authored Jul 4, 2021
2 parents 855b2bb + a46a27e commit 4d7cabc
Show file tree
Hide file tree
Showing 15 changed files with 104 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
ApproxManifoldProducts = "0.4"
ApproxManifoldProducts = "0.4.1"
BSON = "0.2, 0.3"
Combinatorics = "1.0"
DataStructures = "0.16, 0.17, 0.18"
Expand Down
2 changes: 1 addition & 1 deletion src/ApproxConv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ function findRelatedFromPotential(dfg::AbstractDFG,
Npoints = length(pts) # size(pts,2)
# Assume we only have large particle population sizes, thanks to addNode!
M = getManifold(getVariableType(dfg, target))
proposal = AMP.manikde!(pts, M)
proposal = AMP.manikde!(M, pts)

# FIXME consolidate with approxConv method instead
if Npoints != N # this is where we control the overall particle set size
Expand Down
2 changes: 1 addition & 1 deletion src/BeliefTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ TreeBelief( val::AbstractVector{P},
inferdim::Real=0,
variableType::T=ContinuousScalar(),
manifold::M=getManifold(variableType),
solvableDim::Real=0) where {P <: AbstractVector, T <: InferenceVariable, M <:MB.AbstractManifold} = TreeBelief{T,P,M}(val, bw, inferdim, variableType, manifold, solvableDim)
solvableDim::Real=0) where {P, T <: InferenceVariable, M <:MB.AbstractManifold} = TreeBelief{T,P,M}(val, bw, inferdim, variableType, manifold, solvableDim)

function TreeBelief(vnd::VariableNodeData{T}, solvDim::Real=0) where T
TreeBelief( vnd.val, vnd.bw, vnd.inferdim, getVariableType(vnd), getManifold(T), solvDim )
Expand Down
9 changes: 6 additions & 3 deletions src/CompareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ function Base.isapprox( p1::Union{<:BallTreeDensity, <:ManifoldKernelDensity},
mmd(p1,p2) < atol
end

function compareAllSpecial(A::T1, B::T2;
skip=Symbol[], show::Bool=true) where {T1 <: CommonConvWrapper, T2 <: CommonConvWrapper}
function compareAllSpecial( A::T1, B::T2;
skip=Symbol[], show::Bool=true) where {T1 <: CommonConvWrapper, T2 <: CommonConvWrapper}
#
T1 != T2 && return false
if T1 != T2
@warn "CCW types not equal" T1 T2
return false
end
return compareAll(A, B, skip=skip, show=show)
end

Expand Down
3 changes: 2 additions & 1 deletion src/DeconvUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ function approxDeconv(fcto::DFGFactor,
fmd = _getFMdThread(ccw)

# TODO assuming vector on only first container in measurement::Tuple
makeTarget = (i) -> view(measurement[1][i],:)
makeTarget = (i) -> measurement[1][i] # TODO does not support copy-primitive types like Float64, only Ref()
# makeTarget = (i) -> view(measurement[1][i],:)
# makeTarget = (i) -> view(measurement[1], :, i)

# NOTE
Expand Down
5 changes: 5 additions & 0 deletions src/Deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ end
## Deprecate code below before v0.27
##==============================================================================

# getManifold(::InstanceType{LinearRelative{N}}) where {N} = Euclidean(N)
# getManifolds(::T) where {T <: LinearRelative} = convert(Tuple, getManifold(T))
# getManifolds(::Type{<:T}) where {T <: LinearRelative} = convert(Tuple, getManifold(T))
# getManifolds(fctType::Type{LinearRelative}) = getManifolds(getDomain(fctType))

# # FIXME, why is Manifolds depdendent on the solveKey?? Should just be at DFGVariable level?

# getManifolds(vd::VariableNodeData) = getVariableType(vd) |> getManifolds
Expand Down
4 changes: 2 additions & 2 deletions src/FGOSUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ function manikde!(pts::AbstractVector{P},
variableType::Union{InstanceType{<:InferenceVariable}, InstanceType{<:AbstractFactor}} ) where P
#
M = getManifold(variableType)
return AMP.manikde!(pts, bws, M)
return AMP.manikde!(M, pts, bws)
end

function manikde!(pts::AbstractVector{P},
vartype::Union{InstanceType{<:InferenceVariable}, InstanceType{<:AbstractFactor}}) where P
#
M = getManifold(vartype)
return AMP.manikde!(pts, M)
return AMP.manikde!(M, pts)
end

# manikde!( pts::AbstractVector{<:Real},
Expand Down
74 changes: 36 additions & 38 deletions src/FactorGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,33 @@ function setBW!(v::DFGVariable, bw::Array{Float64,2}; solveKey::Symbol=:default)
nothing
end

function setVal!(vd::VariableNodeData, val::AbstractVector{P}) where P <:AbstractVector
function setVal!(vd::VariableNodeData, val::AbstractVector{P}) where P
vd.val = val
nothing
end
function setVal!(v::DFGVariable, val::AbstractVector{P}; solveKey::Symbol=:default) where P <:AbstractVector
function setVal!(v::DFGVariable, val::AbstractVector{P}; solveKey::Symbol=:default) where P
setVal!(getSolverData(v, solveKey), val)
nothing
end
function setVal!(vd::VariableNodeData, val::AbstractVector{P}, bw::Array{Float64,2}) where P <:AbstractVector
function setVal!(vd::VariableNodeData, val::AbstractVector{P}, bw::Array{Float64,2}) where P
setVal!(vd, val)
setBW!(vd, bw)
nothing
end
function setVal!(v::DFGVariable, val::AbstractVector{P}, bw::Array{Float64,2}; solveKey::Symbol=:default) where P <:AbstractVector
function setVal!(v::DFGVariable, val::AbstractVector{P}, bw::Array{Float64,2}; solveKey::Symbol=:default) where P
setVal!(v, val, solveKey=solveKey)
setBW!(v, bw, solveKey=solveKey)
nothing
end
function setVal!(vd::VariableNodeData, val::AbstractVector{P}, bw::Vector{Float64}; solveKey::Symbol=:default) where P <:AbstractVector
function setVal!(vd::VariableNodeData, val::AbstractVector{P}, bw::Vector{Float64}) where P
setVal!(vd, val, reshape(bw,length(bw),1))
nothing
end
function setVal!(v::DFGVariable, val::AbstractVector{P}, bw::Vector{Float64}; solveKey::Symbol=:default) where P <:AbstractVector
function setVal!(v::DFGVariable, val::AbstractVector{P}, bw::Vector{Float64}; solveKey::Symbol=:default) where P
setVal!(getSolverData(v, solveKey=solveKey), val, bw)
nothing
end
function setVal!(dfg::AbstractDFG, sym::Symbol, val::AbstractVector{P}; solveKey::Symbol=:default) where P <:AbstractVector
function setVal!(dfg::AbstractDFG, sym::Symbol, val::AbstractVector{P}; solveKey::Symbol=:default) where P
setVal!(getVariable(dfg, sym), val, solveKey=solveKey)
end

Expand All @@ -128,7 +128,7 @@ function setValKDE!(vd::VariableNodeData,
pts::AbstractVector{P},
bws::Vector{Float64},
setinit::Bool=true,
inferdim::Float64=0.0 ) where P <:AbstractVector
inferdim::Float64=0.0 ) where P
#
setVal!(vd, pts, bws) # BUG ...al!(., val, . ) ## TODO -- this can be a little faster
setinit ? (vd.initialized = true) : nothing
Expand All @@ -139,7 +139,7 @@ end
function setValKDE!(vd::VariableNodeData,
val::AbstractVector{P},
setinit::Bool=true,
inferdim::Real=0.0 ) where P <:AbstractVector
inferdim::Real=0.0 ) where P
# recover variableType information
varType = getVariableType(vd)
p = AMP.manikde!(val, varType)
Expand All @@ -152,7 +152,7 @@ function setValKDE!(v::DFGVariable,
bws::Array{<:Real,2},
setinit::Bool=true,
inferdim::Float64=0;
solveKey::Symbol=:default) where P <:AbstractVector
solveKey::Symbol=:default) where P
# recover variableType information
setValKDE!(getSolverData(v, solveKey), val, bws[:,1], setinit, inferdim )

Expand All @@ -163,7 +163,7 @@ function setValKDE!(v::DFGVariable,
val::AbstractVector{P},
setinit::Bool=true,
inferdim::Float64=0.0;
solveKey::Symbol=:default) where P <:AbstractVector
solveKey::Symbol=:default) where P
# recover variableType information
setValKDE!(getSolverData(v, solveKey),val, setinit, inferdim )
nothing
Expand Down Expand Up @@ -486,6 +486,18 @@ function addVariable!(dfg::AbstractDFG,
return v
end

function _resizePointsVector!(vecP::AbstractVector{P}, mkd::ManifoldKernelDensity, N::Int) where P
#
pN = length(vecP)
resize!(vecP, N)
for j in pN:N
smp = AMP.sample(mkd, 1)[1]
# @show j, smp, typeof(smp), typeof(vecP[j])
vecP[j] = smp[1]
end

vecP
end


"""
Expand All @@ -503,47 +515,34 @@ Notes
- for initialization, solveFor = Nothing.
- `P = getPointType(<:InferenceVariable)`
"""
function prepareparamsarray!( ARR::AbstractVector{P},
function prepareparamsarray!( ARR::AbstractVector{<:AbstractVector{P}},
Xi::Vector{<:DFGVariable},
solvefor::Union{Nothing, Symbol},
N::Int=0;
solveKey::Symbol=:default ) where P <: AbstractVector
solveKey::Symbol=:default ) where P
#
LEN = Int[]
maxlen = N # FIXME see #105
count = 0
sfidx = 0

for xi in Xi
push!(ARR, getVal(xi, solveKey=solveKey))
vecP = getVal(xi, solveKey=solveKey)
push!(ARR, vecP)
LEN = length.(ARR)
maxlen = maximum(LEN)
# @show len = size(ARR[end], 2)
# push!(LEN, len)
# if len > maxlen
# maxlen = len
# end
count += 1
if xi.label == solvefor
sfidx = count #xi.index
end
end

# resample variables with too few kernels
# resample variables with too few kernels (manifolds points)
SAMP = LEN .< maxlen
for i in 1:count
if SAMP[i]
Pr = getBelief(Xi[i], solveKey)
resize!(ARR[i], maxlen)
for j in 1:maxlen
smp = AMP.sample(Pr, 1)[1]
arr_i = ARR[i]
if isdefined(arr_i, j)
arr_i[j][:] = smp[:]
else
arr_i[j] = smp[:]
end
end
_resizePointsVector!(ARR[i], Pr, maxlen)
end
end

Expand Down Expand Up @@ -617,17 +616,16 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},
inflation::Real=0.0 ) where {T <: FunctorInferenceType}
#
pttypes = getVariableType.(Xi) .|> getPointType
sametype = 0 < length(pttypes) ? all( pttypes[1] .== pttypes ) : true
P_type = 0 < length(pttypes) ? Vector{pttypes[1]} : Vector{Float64}
@assert sametype "Current implementation only allows for same point type: $pttypes"
ARR = Vector{P_type}()
PointType = 0 < length(pttypes) ? pttypes[1] : Vector{Float64}
ARR = Vector{Vector{PointType}}()
maxlen, sfidx, mani = prepareparamsarray!(ARR, Xi, nothing, 0) # Nothing for init.
# fldnms = fieldnames(T) # typeof(usrfnc)

# standard factor metadata
sflbl = 0==length(Xi) ? :null : getLabel(Xi[end])
fmd = FactorMetadata(Xi, getLabel.(Xi), ARR, sflbl, nothing)
cf = CalcFactor( usrfnc, fmd, 0, 1, (Vector{Vector{Float64}}(),), ARR)
# guess measurement points type
MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
cf = CalcFactor( usrfnc, fmd, 0, 1, (Vector{MeasType}(),), ARR)

zdim = calcZDim(cf)
# zdim = T != GenericMarginal ? size(getSample(usrfnc, 2)[1],1) : 0
Expand All @@ -643,7 +641,7 @@ function prepgenericconvolution(Xi::Vector{<:DFGVariable},

ccw = CommonConvWrapper(
usrfnc,
P_type(),
PointType[],
zdim,
ARR,
fmd,
Expand Down Expand Up @@ -984,7 +982,7 @@ end
function initManual!( dfg::AbstractDFG,
sym::Symbol,
pts::AbstractVector{P},
solveKey::Symbol=:default ) where {P <: AbstractVector}
solveKey::Symbol=:default ) where {P}
#
var = getVariable(dfg, sym)
pp = manikde!(pts, getManifold(var))
Expand Down
7 changes: 1 addition & 6 deletions src/Factors/LinearRelative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,7 @@ LinearRelative(nm::Distributions.ContinuousUnivariateDistribution) = LinearRelat
LinearRelative(nm::MvNormal) = LinearRelative{length(nm.μ), typeof(nm)}(nm)
LinearRelative(nm::Union{<:BallTreeDensity,<:ManifoldKernelDensity}) = LinearRelative{Ndim(nm), typeof(nm)}(nm)

# getManifold(::InstanceType{LinearRelative{N}}) where {N} = Euclidean(N)
# getManifolds(::T) where {T <: LinearRelative} = convert(Tuple, getManifold(T))
# getManifolds(::Type{<:T}) where {T <: LinearRelative} = convert(Tuple, getManifold(T))
# getManifolds(fctType::Type{LinearRelative}) = getManifolds(getDomain(fctType))

getManifold(::InstanceType{LinearRelative{N}}) where N = ContinuousEuclid{N}
getManifold(::InstanceType{LinearRelative{N}}) where N = getManifold(ContinuousEuclid{N})

# TODO standardize
getDimension(::InstanceType{LinearRelative{N}}) where {N} = N
Expand Down
2 changes: 1 addition & 1 deletion src/ParametricCSMFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel=Logging.Info)
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
#Update subfg variables
vnd.val .= val.val
vnd.val[1] .= val.val
vnd.bw .= val.cov
end
else
Expand Down
2 changes: 1 addition & 1 deletion src/ParametricUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ function solveConditionalsParametric(fg::AbstractDFG,
flatvar = FlatVariables(fg, varIds)

for vId in varIds
flatvar[vId] = getVariableSolverData(fg, vId, solvekey).val[:,1]
flatvar[vId] = getVariableSolverData(fg, vId, solvekey).val[1][:]
end
initValues = flatvar.X

Expand Down
2 changes: 1 addition & 1 deletion src/TreeMessageUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ end


function generateMsgPrior(belief_::TreeBelief, ::NonparametricMessage)
kdePr = manikde!(belief_.val, belief_.bw[:,1], getManifold(belief_.variableType))
kdePr = manikde!(getManifold(belief_.variableType), belief_.val, belief_.bw[:,1])
MsgPrior(kdePr, belief_.inferdim)
end

Expand Down
Loading

0 comments on commit 4d7cabc

Please sign in to comment.