
Commit

Merge pull request #118 from JuliaTrustworthyAI/116-make-plot-and-predict-consistent

116 make plot and predict consistent
pat-alt authored Sep 3, 2024
2 parents f7bf0f4 + e88f54a commit 3700cb9
Showing 4 changed files with 88 additions and 51 deletions.
66 changes: 39 additions & 27 deletions docs/Manifest.toml
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.3"
julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"

@@ -85,9 +85,9 @@ version = "3.5.1+1"

[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra"]
git-tree-sha1 = "f54c23a5d304fb87110de62bace7777d59088c34"
git-tree-sha1 = "3640d077b6dafd64ceb8fd5c1ec76f7ca53bcf76"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.15.0"
version = "7.16.0"

[deps.ArrayInterface.extensions]
ArrayInterfaceBandedMatricesExt = "BandedMatrices"
@@ -209,9 +209,9 @@ version = "0.9.2+0"

[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "f3b237289a5a77c759b2dd5d4c2ff641d67c4030"
git-tree-sha1 = "33576c7c1b2500f8e7e6baa082e04563203b3a45"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.3.4"
version = "0.3.5"

[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
@@ -359,17 +359,18 @@ uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
version = "0.1.13"

[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "a33b7ced222c6165f624a3f2b55945fac5a598d9"
git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.5.7"
version = "1.5.8"

[deps.ConstructionBase.extensions]
ConstructionBaseIntervalSetsExt = "IntervalSets"
+ ConstructionBaseLinearAlgebraExt = "LinearAlgebra"
ConstructionBaseStaticArraysExt = "StaticArrays"

[deps.ConstructionBase.weakdeps]
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.ContextVariablesX]]
@@ -569,19 +570,24 @@ uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.16.3"

[[deps.FilePathsBase]]
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
deps = ["Compat", "Dates"]
git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361"
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
version = "0.9.21"
version = "0.9.22"
weakdeps = ["Mmap", "Test"]

[deps.FilePathsBase.extensions]
FilePathsBaseMmapExt = "Mmap"
FilePathsBaseTestExt = "Test"

[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"

[[deps.FillArrays]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fd0002c0b5362d7eb952450ad5eb742443340d6e"
git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.12.0"
version = "1.13.0"
weakdeps = ["PDMats", "SparseArrays", "Statistics"]

[deps.FillArrays.extensions]
@@ -841,10 +847,10 @@ uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"

[[deps.JLD2]]
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"]
git-tree-sha1 = "67d4690d32c22e28818a434b293a374cc78473d3"
deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"]
git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28"
uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
version = "0.4.51"
version = "0.4.53"

[[deps.JLFzf]]
deps = ["Pipe", "REPL", "Random", "fzf_jll"]
@@ -854,9 +860,9 @@ version = "0.1.8"

[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"
version = "1.6.0"

[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
@@ -884,9 +890,9 @@ version = "0.2.4"

[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "35ceea58aa34ad08b1ae00a52622c62d1cfb8ce2"
git-tree-sha1 = "cb1cff88ef2f3a157cbad75bbe6b229e1975e498"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.24"
version = "0.9.25"

[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
@@ -1444,9 +1450,9 @@ version = "1.4.1"

[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"]
git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf"
git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.40.5"
version = "1.40.8"

[deps.Plots.extensions]
FileIOExt = "FileIO"
@@ -1514,9 +1520,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.2"

[[deps.PtrArrays]]
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.2.0"
version = "1.2.1"

[[deps.Qt6Base_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"]
@@ -1544,9 +1550,15 @@ version = "6.7.1+1"

[[deps.QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "e237232771fdafbae3db5c31275303e056afaa9f"
git-tree-sha1 = "1d587203cf851a51bf1ea31ad7ff89eff8d625ea"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.10.1"
version = "2.11.0"

[deps.QuadGK.extensions]
QuadGKEnzymeExt = "Enzyme"

[deps.QuadGK.weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[[deps.RData]]
deps = ["CategoricalArrays", "CodecZlib", "DataFrames", "Dates", "FileIO", "Requires", "TimeZones", "Unicode"]
@@ -2274,7 +2286,7 @@ version = "0.15.2+0"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"
version = "5.11.0+0"

[[deps.libdecor_jll]]
deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"]
65 changes: 45 additions & 20 deletions src/baselaplace/predicting.jl
@@ -23,19 +23,24 @@ function has_softmax_or_sigmoid_final_layer(model::Flux.Chain)
return has_finaliser
end

"""
@doc raw"""
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)
Compute the functional variance for the GLM predictive. Dispatches to the appropriate method based on the Hessian structure.
Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is a (output x output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate.
Dispatches to the appropriate method based on the Hessian structure.
"""
function functional_variance(la, 𝐉)
return functional_variance(la, la.est_params.hessian_structure, 𝐉)
end

"""
@doc raw"""
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
Computes the linearized GLM predictive.
Computes the linearized GLM predictive from neural network with a Laplace approximation to the posterior ``p(\theta|\mathcal{D})=\mathcal{N}(\hat\theta,\Sigma)``.
This is the distribution on network outputs given by ``p(f(x)|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta})``.
For the Bayesian predictive distribution, see [`predict`](@ref).
# Arguments
@@ -49,7 +54,7 @@ Computes the linearized GLM predictive.
# Examples
- ```julia-repl
+ ```julia
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
@@ -58,42 +63,55 @@ nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
```
"""
function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X)
fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
fstd = sqrt.(fvar)
- normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
+ normal_distr = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)]
return (normal_distr, fμ, fvar)
end

"""
predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)
@doc raw"""
predict(
la::AbstractLaplace,
X::AbstractArray;
link_approx=:probit,
predict_proba::Bool=true,
ret_distr::Bool=false,
)
- Computes predictions from Bayesian neural network.
+ Computes the Bayesian predictive distribution from a neural network with a Laplace approximation to the posterior ``p(\theta | \mathcal{D}) = \mathcal{N}(\hat\theta, \Sigma)``.
# Arguments
- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default) apply a sigmoid or a softmax function to the output of the Flux model.
- - `return_distr::Bool=false`: if `false` (default), the function output either the direct output of the chain or pseudo-probabilities (if predict_proba= true).
+ - `ret_distr::Bool=false`: if `false` (default), the function outputs either the direct output of the chain or pseudo-probabilities (if `predict_proba=true`).
+   If `true`, `predict` returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multi-class classification tasks.
# Returns
- For classification tasks, LaplaceRedux provides different options:
- if ret_distr is false:
- - `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- if ret_distr is true:
- - a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.
+ For classification tasks:
+ 1. If `ret_distr` is `false`, `predict` returns `fμ`, i.e. the mean of the predictive distribution, which corresponds to the MAP estimate if the link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
+ 2. If `ret_distr` is `true`, `predict` returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multi-class classification tasks.
For regression tasks:
- - `normal_distr::Distributions.Normal`:the array of Normal distributions computed by glm_predictive_distribution.
+ 1. If `ret_distr` is `false`, `predict` returns the mean and the variance of the predictive distribution. The output shape is column-major as in Flux.
+ 2. If `ret_distr` is `true`, `predict` returns the predictive posterior distribution, namely:
+    ``p(y|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta} + \sigma^2 \mathbf{I})``
# Examples
- ```julia-repl
+ ```julia
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
@@ -111,15 +129,22 @@ function predict(
predict_proba::Bool=true,
ret_distr::Bool=false,
)
- normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
+ _, fμ, fvar = glm_predictive_distribution(la, X)

# Regression:
if la.likelihood == :regression

+ # Add observational noise:
+ pred_var = fvar .+ la.prior.σ^2
+ fstd = sqrt.(pred_var)
+ pred_dist = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)]

if ret_distr
- return reshape(normal_distr, (:, 1))
+ return reshape(pred_dist, (:, 1))
else
- return fμ, fvar
+ return fμ, pred_var
end

end

# Classification:
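This is the substance of the consistency fix: for regression, `predict` now returns the variance with the observational noise `la.prior.σ^2` added, matching what the plotting recipe shows. A minimal sketch of the resulting behaviour, assuming a small hand-rolled toy regression problem (the data, layer sizes, and exported names used below are illustrative assumptions, not taken from this diff):

```julia
using Flux, LaplaceRedux

# Hypothetical 1-D toy regression data (illustrative only):
x = [randn(Float32, 1) for _ in 1:50]
y = [sin(first(xi)) + 0.1f0 * randn(Float32) for xi in x]
data = zip(x, y)

# Fit a Laplace approximation to a small regression network:
nn = Chain(Dense(1, 10, tanh), Dense(10, 1))
la = Laplace(nn; likelihood=:regression)
fit!(la, data)

X = hcat(x...)

# GLM predictive: functional (epistemic) variance only.
_, fμ, fvar = glm_predictive_distribution(la, X)

# Bayesian predictive: functional variance plus observational noise σ².
fμ2, pred_var = predict(la, X)
@assert pred_var ≈ fvar .+ la.prior.σ^2
```

Before this commit, the second call returned the bare `fvar` (see the removed `return fμ, fvar` above), so plotted uncertainty bands and `predict` outputs disagreed.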
6 changes: 3 additions & 3 deletions src/full.jl
@@ -50,10 +50,10 @@ function _fit!(
return la.posterior.n_data = n_data
end

"""
functional_variance(la::Laplace,𝐉)
@doc raw"""
functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉)
Compute the linearized GLM predictive variance as `𝐉ₙΣ𝐉ₙ'` where `𝐉=∇f(x;θ)|θ̂` is the Jacobian evaluated at the MAP estimate and `Σ = P⁻¹`.
Computes the functional variance for the GLM predictive as `map(j -> (j' * Σ * j), eachrow(𝐉))` which is a (output x output) predictive covariance matrix. Formally, we have ``{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}`` where ``\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta`` is the Jacobian evaluated at the MAP estimate.
"""
function functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉)
Σ = posterior_covariance(la)
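As a sanity check on the docstring above, the `map` over Jacobian rows computes exactly the diagonal of the full predictive covariance. A self-contained sketch in plain Julia (standalone illustration, not the package API):

```julia
using LinearAlgebra

# Jacobian: 3 outputs × 5 parameters (random stand-ins):
J = randn(3, 5)
# Symmetric positive-definite posterior covariance Σ:
A = randn(5, 5)
Σ = A * A' + I

# Row-wise quadratic forms, as in the docstring...
fvar = map(j -> j' * Σ * j, eachrow(J))
# ...equal the diagonal of the (output × output) covariance 𝐉Σ𝐉ᵀ:
@assert fvar ≈ diag(J * Σ * J')
```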
2 changes: 1 addition & 1 deletion src/kronecker/kron.jl
@@ -133,7 +133,7 @@ function _fit!(
end

"""
- functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)
+ functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)
Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ,
where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output `𝐉=∇f(x;θ)|θ̂`.
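The usual reason Kronecker-factored curvature is tractable is that the full posterior precision 𝐏 never needs to be materialised: `inv(kron(A, B)) == kron(inv(A), inv(B))`. A tiny standalone illustration of that identity (not necessarily how this package implements it):

```julia
using LinearAlgebra

# Two small Kronecker factors (stand-ins for layer-wise curvature factors):
A = [2.0 0.5; 0.5 1.0]
B = [3.0 1.0; 1.0 2.0]
P = kron(A, B)  # full precision is only 4×4 here, but much larger in practice

# Inverting the small factors is equivalent to inverting the full matrix:
@assert inv(P) ≈ kron(inv(A), inv(B))
```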

2 comments on commit 3700cb9

@pat-alt (Member Author) commented on 3700cb9, Sep 3, 2024

@JuliaRegistrator commented:
Error while trying to register: Version 1.0.2 already exists
