Merge pull request #187 from biaslab/nlpm
Refactor Nonlinear dimensionality specification
ThijsvdLaar authored Dec 15, 2021
2 parents b10b720 + 9fd2cbb commit fc11951
Showing 48 changed files with 456 additions and 1,221 deletions.
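
The recurring change below (an inference from the hunks themselves, not a statement from the commit message): dimensionalities are now carried as size-like tuples rather than plain integers, with the empty tuple () marking a scalar, so call sites index the tuple, e.g. dims(dist_out)[1]. A minimal sketch of that convention in plain Julia, with hypothetical values:

d = (3,)                  # hypothetical dimensionality tuple for a length-3 multivariate
d == () ? 1 : d[1]        # recovers the leading dimension, mirroring dims(dist)[1] in the hunks below
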
606 changes: 10 additions & 596 deletions demo/bootstrap_particle_filter.ipynb

Large diffs are not rendered by default.

31 changes: 15 additions & 16 deletions demo/nonlinear_kalman_filter.ipynb

Large diffs are not rendered by default.

80 changes: 9 additions & 71 deletions demo/nonlinear_online_estimation.ipynb

Large diffs are not rendered by default.

41 changes: 16 additions & 25 deletions demo/variational_laplace_and_sampling.ipynb
@@ -27,16 +27,7 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Precompiling ForneyLab [9fc3f58a-c2cc-5bff-9419-6a294fefdca9]\n",
"└ @ Base loading.jl:1273\n"
]
}
],
"outputs": [],
"source": [
"using ForneyLab, LinearAlgebra\n",
"\n",
@@ -206,7 +197,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for l is a ProbabilityDistribution{Univariate,SampleList} with mean 0.636 and variance 0.014\n"
"The marginal for l is a ProbabilityDistribution{Univariate, SampleList} with mean 0.611 and variance 0.014\n"
]
}
],
@@ -320,7 +311,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=1.40, w=2.33)\n"
"𝒩(xi=1.40, w=2.32)\n"
]
},
"execution_count": 14,
@@ -341,7 +332,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Free energy per iteration: 2.037, 1.952, 1.957, 1.942, 1.943"
"Free energy per iteration: 2.041, 1.946, 1.938, 1.943, 1.939"
]
}
],
@@ -437,7 +428,7 @@
{
"data": {
"text/plain": [
"Dir(a=[2.35, 5.15, 3.20])\n"
"Dir(a=[2.34, 5.16, 3.20])\n"
]
},
"execution_count": 19,
@@ -458,10 +449,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for x is a ProbabilityDistribution{Univariate,SampleList} with mean vector entries\n",
" [1] = 0.227738\n",
" [2] = 0.772048\n",
" [3] = 0.000214287\n"
"The marginal for x is a ProbabilityDistribution{Univariate, SampleList} with mean vector entries\n",
" [1] = 0.225319\n",
" [2] = 0.774444\n",
" [3] = 0.000237294\n"
]
}
],
@@ -480,7 +471,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Free energy per iteration: 2.463, 2.459, 2.468, 2.416, 2.449"
"Free energy per iteration: 2.504, 2.43, 2.44, 2.426, 2.492"
]
}
],
@@ -553,7 +544,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=0.91, w=0.99)\n"
"𝒩(xi=0.77, w=0.85)\n"
]
},
"execution_count": 25,
@@ -573,7 +564,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=6.75, w=3.63)\n"
"𝒩(xi=6.10, w=3.29)\n"
]
},
"execution_count": 26,
@@ -594,7 +585,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for m is a ProbabilityDistribution{Univariate,SampleList} with mean 4.229 and variance 0.979\n"
"The marginal for m is a ProbabilityDistribution{Univariate, SampleList} with mean 4.175 and variance 0.902\n"
]
}
],
@@ -617,15 +608,15 @@
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.3.0",
"display_name": "Julia 1.6.4",
"language": "julia",
"name": "julia-1.3"
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.3.0"
"version": "1.6.4"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion src/engines/julia/generators.jl
@@ -192,7 +192,7 @@ function vagueSourceCode(entry::ScheduleEntry)
family_code = removePrefix(entry.family)
dims = entry.dimensionality
if dims == ()
vague_code = "vague($family_code)"
vague_code = "vague($family_code)" # Default
else
vague_code = "vague($family_code, $dims)"
end
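
For illustration only (hypothetical family and dimensions, not taken from this commit), the interpolation above now embeds the dimensionality tuple directly in the generated call:

family_code = "GaussianMeanVariance"
dims = (3,)
vague_code = dims == () ? "vague($family_code)" : "vague($family_code, $dims)"
# -> "vague(GaussianMeanVariance, (3,))"; for dims == () it falls back to "vague(GaussianMeanVariance)"
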
12 changes: 6 additions & 6 deletions src/engines/julia/update_rules/gaussian_mean_precision.jl
@@ -54,7 +54,7 @@ function ruleVBGaussianMeanPrecisionW( dist_out::ProbabilityDistribution{Multiv
(m_mean, v_mean) = unsafeMeanCov(dist_mean)
(m_out, v_out) = unsafeMeanCov(dist_out)

Message(MatrixVariate, Wishart, v=cholinv( v_mean + v_out + (m_mean - m_out)*(m_mean - m_out)' ), nu=dims(dist_out) + 2.0)
Message(MatrixVariate, Wishart, v=cholinv( v_mean + v_out + (m_mean - m_out)*(m_mean - m_out)' ), nu=dims(dist_out)[1] + 2.0)
end

ruleVBGaussianMeanPrecisionOut( dist_out::Any,
@@ -63,21 +63,21 @@ ruleVBGaussianMeanPrecisionOut( dist_out::Any,
Message(V, GaussianMeanPrecision, m=unsafeMean(dist_mean), w=unsafeMean(dist_prec))

ruleSVBGaussianMeanPrecisionOutVGD(dist_out::Any,
msg_mean::Message{F, V},
dist_prec::ProbabilityDistribution) where{F<:Gaussian, V<:VariateType} =
msg_mean::Message{<:Gaussian, V},
dist_prec::ProbabilityDistribution) where V<:VariateType =
Message(V, GaussianMeanVariance, m=unsafeMean(msg_mean.dist), v=unsafeCov(msg_mean.dist) + cholinv(unsafeMean(dist_prec)))

function ruleSVBGaussianMeanPrecisionW(
dist_out_mean::ProbabilityDistribution{Multivariate, F},
dist_prec::Any) where F<:Gaussian

joint_dims = dims(dist_out_mean)
joint_d = dims(dist_out_mean)[1]
d_out_mean = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, dist_out_mean)
(m, V) = unsafeMeanCov(d_out_mean)
if joint_dims == 2
if joint_d == 2
return Message(Univariate, Gamma, a=1.5, b=0.5*(V[1,1] - V[1,2] - V[2,1] + V[2,2] + (m[1] - m[2])^2))
else
d = Int64(joint_dims/2)
d = Int64(joint_d/2)
return Message(MatrixVariate, Wishart, v=cholinv( V[1:d,1:d] - V[1:d,d+1:end] - V[d+1:end, 1:d] + V[d+1:end,d+1:end] + (m[1:d] - m[d+1:end])*(m[1:d] - m[d+1:end])' ), nu=d + 2.0)
end
end
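
For readers tracing the indexing above: the joint belief stacks the out and mean components, so the half-blocks of V and halves of m combine into E[(out - mean)(out - mean)'], which is inverted into the Wishart scale (and, scaled by 0.5, gives the Gamma rate in the joint_d == 2 branch). A small numeric sketch in plain Julia with hypothetical values:

m = [0.5, 0.2]                    # stacked means [m_out; m_mean], d = 1 each
V = [1.0 0.3; 0.3 2.0]            # joint covariance
d = Int64(length(m)/2)
S = V[1:d,1:d] - V[1:d,d+1:end] - V[d+1:end,1:d] + V[d+1:end,d+1:end] +
    (m[1:d] - m[d+1:end])*(m[1:d] - m[d+1:end])'    # 1x1 matrix [2.49]
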
4 changes: 2 additions & 2 deletions src/engines/julia/update_rules/gaussian_mixture.jl
@@ -55,7 +55,7 @@ function ruleVBGaussianMixtureW(dist_out::ProbabilityDistribution,
(m_mean_k, v_mean_k) = unsafeMeanCov(dist_means[k])
(m_out, v_out) = unsafeMeanCov(dist_out)
z_bar = unsafeMeanVector(dist_switch)
d = dims(dist_means[1])
d = dims(dist_means[1])[1]

return Message(MatrixVariate, Wishart,
nu = 1.0 + z_bar[k] + d,
@@ -123,7 +123,7 @@ function ruleVBGaussianMixtureOut( dist_out::Any,
dist_means = collect(dist_factors[1:2:end])
dist_precs = collect(dist_factors[2:2:end])
z_bar = unsafeMeanVector(dist_switch)
d = dims(dist_means[1])
d = dims(dist_means[1])[1]

w = Diagonal(zeros(d))
xi = zeros(d)
1 change: 0 additions & 1 deletion src/engines/julia/update_rules/multiplication.jl
@@ -139,7 +139,6 @@ function ruleSPMultiplicationIn1GNP(msg_out::Message{F, Multivariate},

dist_a_matr = convert(ProbabilityDistribution{MatrixVariate, PointMass}, msg_a.dist)
msg_in1_mult = ruleSPMultiplicationIn1GNP(msg_out, nothing, Message(dist_a_matr))
(dims(msg_in1_mult.dist) == 1) || error("Implicit conversion to Univariate failed for $(msg_in1_mult.dist)")

return Message(Univariate, GaussianWeightedMeanPrecision, xi=msg_in1_mult.dist.params[:xi][1], w=msg_in1_mult.dist.params[:w][1,1])
end
81 changes: 47 additions & 34 deletions src/engines/julia/update_rules/nonlinear_extended.jl
@@ -6,41 +6,46 @@ ruleSPNonlinearEInGX,
ruleMNonlinearEInGX

"""
Concatenate a vector of vectors and return with original dimensions (for splitting)
Concatenate a vector (of vectors and floats) and return with original dimensions (for splitting)
"""
function concatenate(xs::Vector{Vector{Float64}})
ds = [length(x_k) for x_k in xs] # Extract dimensions
function concatenate(xs::Vector)
ds = [size(x_k) for x_k in xs] # Extract dimensions
x = vcat(xs...)

return (x, ds)
end
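
A quick illustration of the relaxed signature (the assumption being that scalar and vector entries may now be mixed, as the updated docstring says); plain Julia, hypothetical values:

xs = [1.0, [2.0, 3.0]]            # a scalar and a vector
ds = [size(x_k) for x_k in xs]    # [(), (2,)] -- the empty tuple marks the scalar entry
x = vcat(xs...)                   # [1.0, 2.0, 3.0]
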

"""
Return local linearization of g around expansion point x_hat
for Nonlinear node with single input interface
"""
function localLinearization(V::Type{Univariate}, g::Function, x_hat::Float64)
function localLinearizationSingleIn(g::Function, x_hat::Float64)
a = ForwardDiff.derivative(g, x_hat)
b = g(x_hat) - a*x_hat

return (a, b)
end

function localLinearization(V::Type{Multivariate}, g::Function, x_hat::Vector{Float64})
function localLinearizationSingleIn(g::Function, x_hat::Vector{Float64})
A = ForwardDiff.jacobian(g, x_hat)
b = g(x_hat) - A*x_hat

return (A, b)
end

function localLinearization(V::Type{Univariate}, g::Function, x_hat::Vector{Float64})
"""
Return local linearization of g around expansion point x_hat
for Nonlinear node with multiple input interfaces
"""
function localLinearizationMultiIn(g::Function, x_hat::Vector{Float64})
g_unpacked(x::Vector) = g(x...)
A = ForwardDiff.gradient(g_unpacked, x_hat)'
b = g(x_hat...) - A*x_hat

return (A, b)
end
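
For intuition, a self-contained sketch of the linearizations these helpers compute; the example nonlinearities are hypothetical and not part of this commit:

using ForwardDiff

g1(x) = x^2                                          # single-input case
a = ForwardDiff.derivative(g1, 2.0)                  # a = 4.0
b = g1(2.0) - a*2.0                                  # b = -4.0, so g1(x) ≈ a*x + b near x_hat = 2.0

g2(x, y) = x*y                                       # multi-input case with scalar arguments
g2_unpacked(v) = g2(v...)
A = ForwardDiff.gradient(g2_unpacked, [2.0, 3.0])'   # row vector [3.0 2.0]
b2 = g2(2.0, 3.0) - A*[2.0, 3.0]                     # b2 = -6.0
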

function localLinearization(V::Type{Multivariate}, g::Function, x_hat::Vector{Vector{Float64}})
function localLinearizationMultiIn(g::Function, x_hat::Vector{Vector{Float64}})
(x_cat, ds) = concatenate(x_hat)
g_unpacked(x::Vector) = g(split(x, ds)...)
A = ForwardDiff.jacobian(g_unpacked, x_cat)
@@ -57,74 +62,82 @@ end
# Forward rule
function ruleSPNonlinearEOutNG(g::Function,
msg_out::Nothing,
msg_in1::Message{<:Gaussian, V}) where V<:VariateType
msg_in1::Message{<:Gaussian})

(m_in1, V_in1) = unsafeMeanCov(msg_in1.dist)
(A, b) = localLinearization(V, g, m_in1)
(A, b) = localLinearizationSingleIn(g, m_in1)
m = A*m_in1 + b
V = A*V_in1*A'

return Message(GaussianMeanVariance, A*m_in1 + b, A*V_in1*A') # Automatically determine VariateType
return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end
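
To make the forward pass concrete: the rule propagates the incoming moments through the local linearization, m = A*m_in1 + b and V = A*V_in1*A', and the message's variate type is apparently read off the computed mean via variateType(m). A small numeric sketch in plain Julia (hypothetical numbers, continuing the g(x) = x^2 example above):

A, b = 4.0, -4.0          # linearization of g(x) = x^2 around x_hat = 2.0
m_in1, V_in1 = 2.0, 0.1   # incoming Gaussian mean and variance
m = A*m_in1 + b           # 4.0
V = A*V_in1*A'            # 1.6
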

# Multi-argument forward rule
function ruleSPNonlinearEOutNGX(g::Function, # Needs to be in front of Vararg
msg_out::Nothing,
msgs_in::Vararg{Message{<:Gaussian, V}}) where V<:VariateType
msgs_in::Vararg{Message{<:Gaussian}})

(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g, ms_fw_in)
(A, b) = localLinearizationMultiIn(g, ms_fw_in)
(m_fw_in, V_fw_in, _) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
m = A*m_fw_in + b
V = A*V_fw_in*A'

return Message(GaussianMeanVariance, A*m_fw_in + b, A*V_fw_in*A') # Automatically determine VariateType
return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Backward rule with given inverse
function ruleSPNonlinearEIn1GG(g::Function,
g_inv::Function,
msg_out::Message{<:Gaussian, V},
msg_in1::Nothing) where V<:VariateType
msg_out::Message{<:Gaussian},
msg_in1::Nothing)

(m_out, V_out) = unsafeMeanCov(msg_out.dist)
(A, b) = localLinearization(V, g_inv, m_out)
(A, b) = localLinearizationSingleIn(g_inv, m_out)
m = A*m_out + b
V = A*V_out*A'

return Message(GaussianMeanVariance, A*m_out + b, A*V_out*A') # Automatically determine VariateType
return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Multi-argument backward rule with given inverse
function ruleSPNonlinearEInGX(g::Function, # Needs to be in front of Vararg
g_inv::Function,
msg_out::Message{<:Gaussian},
msgs_in::Vararg{Union{Message{<:Gaussian, V}, Nothing}}) where V<:VariateType
msgs_in::Vararg{Union{Message{<:Gaussian}, Nothing}})

(ms, Vs) = collectStatistics(msg_out, msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g_inv, ms)
(A, b) = localLinearizationMultiIn(g_inv, ms)
(mc, Vc) = concatenateGaussianMV(ms, Vs)
m = A*mc + b
V = A*Vc*A'

return Message(V, GaussianMeanVariance, m=A*mc, v=A*Vc*A')
return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Backward rule with unknown inverse
function ruleSPNonlinearEIn1GG(g::Function,
msg_out::Message{<:Gaussian},
msg_in1::Message{<:Gaussian, V}) where V<:VariateType
msg_in1::Message{<:Gaussian})

m_in1 = unsafeMean(msg_in1.dist)
d_out = convert(ProbabilityDistribution{V, GaussianMeanPrecision}, msg_out.dist)
m_out = d_out.params[:m]
W_out = d_out.params[:w]
(A, b) = localLinearization(V, g, m_in1)
(m_out, W_out) = unsafeMeanPrecision(msg_out.dist)
(A, b) = localLinearizationSingleIn(g, m_in1)
xi = A'*W_out*(m_out - b)
W = A'*W_out*A

return Message(V, GaussianWeightedMeanPrecision, xi=A'*W_out*(m_out - b), w=A'*W_out*A)
return Message(variateType(xi), GaussianWeightedMeanPrecision, xi=xi, w=W)
end

# Multi-argument backward rule with unknown inverse
function ruleSPNonlinearEInGX(g::Function,
inx::Int64, # Index of inbound interface inx
msg_out::Message{<:Gaussian},
msgs_in::Vararg{Message{<:Gaussian, V}}) where V<:VariateType
msgs_in::Vararg{Message{<:Gaussian}})

# Approximate joint inbounds
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g, ms_fw_in)
(A, b) = localLinearizationMultiIn(g, ms_fw_in)

(m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
m_fw_out = A*m_fw_in + b
@@ -136,27 +149,27 @@ function ruleSPNonlinearEInGX(g::Function,
(m_in, V_in) = smoothRTS(m_fw_out, V_fw_out, C_fw, m_fw_in, V_fw_in, m_bw_out, V_bw_out)

# Marginalize joint belief on in's
(m_inx, V_inx) = marginalizeGaussianMV(V, m_in, V_in, ds, inx) # Marginalization is overloaded on VariateType V
(m_inx, V_inx) = marginalizeGaussianMV(m_in, V_in, ds, inx)
W_inx = cholinv(V_inx) # Convert to canonical statistics
xi_inx = W_inx*m_inx

# Divide marginal on inx by forward message
(xi_fw_inx, W_fw_inx) = unsafeWeightedMeanPrecision(msgs_in[inx].dist)
xi_bw_inx = xi_inx - xi_fw_inx
W_bw_inx = W_inx - W_fw_inx # Note: subtraction might lead to posdef inconsistencies
W_bw_inx = W_inx - W_fw_inx # Note: subtraction might lead to posdef violations

return Message(V, GaussianWeightedMeanPrecision, xi=xi_bw_inx, w=W_bw_inx)
return Message(variateType(xi_bw_inx), GaussianWeightedMeanPrecision, xi=xi_bw_inx, w=W_bw_inx)
end

function ruleMNonlinearEInGX(g::Function,
msg_out::Message{<:Gaussian},
msgs_in::Vararg{Message{<:Gaussian, V}}) where V<:VariateType
msgs_in::Vararg{Message{<:Gaussian}})

# Approximate joint inbounds
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g, ms_fw_in)
(A, b) = localLinearizationMultiIn(g, ms_fw_in)

(m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
(m_fw_in, V_fw_in, _) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
m_fw_out = A*m_fw_in + b
V_fw_out = A*V_fw_in*A'
C_fw = V_fw_in*A'
Remaining file diffs are not rendered.
