From e907b5ebfe41e2d6e524d968d22cbe7819a2ed31 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Tue, 26 Nov 2024 19:37:16 +0100 Subject: [PATCH 01/10] fixed scitype and added aliases --- docs/src/tutorials/regression.qmd | 2 +- src/direct_mlj.jl | 90 ++++++++++++++++++------------- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/docs/src/tutorials/regression.qmd b/docs/src/tutorials/regression.qmd index 2faee5a..5059649 100644 --- a/docs/src/tutorials/regression.qmd +++ b/docs/src/tutorials/regression.qmd @@ -128,7 +128,7 @@ then we can plot the calibration plot of our neural model ```{julia} #| output: true -Calibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20) +calibration_plot(la,y_test,vec(predicted_distributions);n_bins = 20) ``` and compute the sharpness of the predictive distribution diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 153f6a5..9996efc 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -19,9 +19,9 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)) backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + observational_noise::Float64 = 1.0 + prior_mean::Float64 = 0.0 + prior_precision_matrix::Union{AbstractMatrix,UniformScaling,Nothing} = nothing fit_prior_nsteps::Int = 100::(_ > 0) link_approx::Symbol = :probit::(_ in (:probit, :plugin)) end @@ -37,14 +37,32 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)) backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing + observational_noise::Float64 = 1.0 + prior_mean::Float64 = 0.0 + prior_precision_matrix::Union{AbstractMatrix,UniformScaling,Nothing} = nothing fit_prior_nsteps::Int = 100::(_ > 0) end LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier} +# Aliases +const LM = LaplaceModels +function Base.getproperty(ce::LM, sym::Symbol) + sym = sym === :σ ? :observational_noise : sym + sym = sym === :μ₀ ? :prior_mean : sym + sym = sym === :P₀ ? :prior_precision_matrix : sym + return Base.getfield(ce, sym) +end +function Base.setproperty!(ce::LM, sym::Symbol, val) + sym = sym === :σ ? :observational_noise : sym + sym = sym === :μ₀ ? :prior_mean : sym + sym = sym === :P₀ ? 
:prior_precision_matrix : sym + return Base.setfield!(ce, sym, val) +end + + + + # for fit: function MMI.reformat(::LaplaceRegressor, X, y) return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing)) @@ -193,9 +211,9 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, + σ=m.observational_noise, + μ₀=m.prior_mean, + P₀=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier @@ -282,9 +300,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, + σ=m.observational_noise, + μ₀=m.prior_mean, + P₀=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier la.likelihood = :classification @@ -315,9 +333,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) :subnetwork_indices, :hessian_structure, :backend, - :σ, - :μ₀, - :P₀, + :observational_noise, + :prior_mean, + :prior_precision_matrix, ) println(" updating only the laplace optimization part") old_la = old_fitresult[1] @@ -329,9 +347,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.σ, - μ₀=m.μ₀, - P₀=m.P₀, + σ=m.observational_noise, + μ₀=m.prior_mean, + P₀=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier la.likelihood = :classification @@ -452,10 +470,10 @@ end # Returns A named tuple containing: - - `μ`: The mean of the posterior distribution. + - `mean`: The mean of the posterior distribution. - `H`: The Hessian of the posterior distribution. - `P`: The precision matrix of the posterior distribution. - - `Σ`: The covariance matrix of the posterior distribution. + - `cov_matrix`: The covariance matrix of the posterior distribution. - `n_data`: The number of data points. - `n_params`: The number of parameters. - `n_out`: The number of outputs. 
@@ -466,10 +484,10 @@ function MMI.fitted_params(model::LaplaceModels, fitresult)
     la, decode = fitresult
     posterior = la.posterior
     return (
-        μ=posterior.μ,
+        mean=posterior.μ,
         H=posterior.H,
         P=posterior.P,
-        Σ=posterior.Σ,
+        cov_matrix=posterior.Σ,
         n_data=posterior.n_data,
         n_params=posterior.n_params,
         n_out=posterior.n_out,
@@ -517,7 +535,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew)
         means, variances = yhat
 
         # Create Normal distributions from the means and variances
-        return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)])
+        return vec([Normal(mean, sqrt(variance)) for (mean, variance) in zip(means, variances)])
 
     else
         predictions =
@@ -551,7 +569,7 @@ MMI.metadata_pkg(
 MLJBase.metadata_model(
     LaplaceClassifier;
     input_scitype=Union{
-        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
+        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Infinite}}, # matrix with mixed types
         MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
     },
     target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
@@ -562,8 +580,8 @@ MLJBase.metadata_model(
 MLJBase.metadata_model(
     LaplaceRegressor;
     input_scitype=Union{
-        AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
-        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
+        AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Infinite}}, # matrix with mixed types
+        MLJBase.Table(MLJBase.Finite, MLJBase.Infinite), # table with mixed types
     },
     target_scitype=AbstractArray{MLJBase.Continuous},
     supports_training_losses=true,
@@ -626,11 +644,11 @@ Train the machine using `fit!(mach, rows=...)`.
 
 - `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
 
-- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+- `observational_noise (alias σ)::Float64 = 1.0`: the standard deviation of the observational noise.
 
-- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+- `prior_mean (alias μ₀)::Float64 = 0.0`: the mean of the prior distribution.
 
-- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+- `prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the precision matrix of the prior distribution.
 
 - `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
 
@@ -650,13 +668,13 @@ Train the machine using `fit!(mach, rows=...)`.
 
 The fields of `fitted_params(mach)` are:
 
- - `μ`: The mean of the posterior distribution.
+ - `mean`: The mean of the posterior distribution.
 
 - `H`: The Hessian of the posterior distribution.
 
 - `P`: The precision matrix of the posterior distribution.
 
- - `Σ`: The covariance matrix of the posterior distribution.
+ - `cov_matrix`: The covariance matrix of the posterior distribution.
 
 - `n_data`: The number of data points.
 
@@ -766,11 +784,11 @@ Train the machine using `fit!(mach, rows=...)`.
 
 - `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
 
-- `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+- `observational_noise (alias σ)::Float64 = 1.0`: the standard deviation of the observational noise.
 
-- `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+- `prior_mean (alias μ₀)::Float64 = 0.0`: the mean of the prior distribution.
-- `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+- `prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the precision matrix of the prior distribution.
 
 - `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
 
@@ -789,13 +807,13 @@ Train the machine using `fit!(mach, rows=...)`.
 
 The fields of `fitted_params(mach)` are:
 
- - `μ`: The mean of the posterior distribution.
+ - `mean`: The mean of the posterior distribution.
 
 - `H`: The Hessian of the posterior distribution.
 
 - `P`: The precision matrix of the posterior distribution.
 
- - `Σ`: The covariance matrix of the posterior distribution.
+ - `cov_matrix`: The covariance matrix of the posterior distribution.
 
 - `n_data`: The number of data points.

From 7e3516fc8506581033ad1f5abf94f39769506d6e Mon Sep 17 00:00:00 2001
From: "pasquale c." <343guiltyspark@outlook.it>
Date: Tue, 26 Nov 2024 19:54:02 +0100
Subject: [PATCH 02/10] added info in changelog

---
 CHANGELOG.md | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0958616..6068366 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 
 *Note*: We try to adhere to these practices as of version [v0.2.1].
 
+
+## Version [2.0.0] - 2024-11-26
+
+
+### Added
+
+- added support to MLJ [#126] [#134]
+
+
 ## Version [1.1.1] - 2024-09-12
 
 ### Changed

From 292304ee13c4924a7f54fe554833a30a4700f996 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 3 Dec 2024 09:09:42 +0100
Subject: [PATCH 03/10] added tests for aliases

---
 test/direct_mlj_interface.jl | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/test/direct_mlj_interface.jl b/test/direct_mlj_interface.jl
index 6a0d642..f3bb228 100644
--- a/test/direct_mlj_interface.jl
+++ b/test/direct_mlj_interface.jl
@@ -14,6 +14,14 @@ cv = MLJBase.CV(; nfolds=3)
     flux_model = Chain(Dense(4, 10, relu), Dense(10, 10, relu), Dense(10, 1))
     model = LaplaceRegressor(; model=flux_model, epochs=20)
 
+    # Aliases:
+    model.σ = model.observational_noise
+    model.μ₀ = model.prior_mean
+    model.P₀ = model.prior_precision_matrix
+    @test model.observational_noise == model.σ
+    @test model.prior_mean == model.μ₀
+    @test model.prior_precision_matrix == model.P₀
+
     #testing more complex dataset
     X, y = MLJBase.make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
     #train, test = partition(eachindex(y), 0.7); # 70:30 split

From 8a73fe253cbe1f1a1a13ce572f5b2c150066404d Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 3 Dec 2024 09:30:10 +0100
Subject: [PATCH 04/10] addressing other open tasks

---
 .github/workflows/CI.yml          |  1 -
 CHANGELOG.md                      |  8 ++++---
 Project.toml                      |  2 +-
 src/baselaplace/core_struct.jl    | 34 +++++++++++++++++++--------
 src/baselaplace/optimize_prior.jl |  4 ++--
 src/baselaplace/predicting.jl     |  2 +-
 src/baselaplace/prior.jl          | 38 ++++++++++++++++++++++---------
 src/baselaplace/utils.jl          | 20 ++++++++--------
 8 files changed, 71 insertions(+), 38 deletions(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 22e67ad..acf0276 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -20,7 +20,6 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.9'
           - '1.10'
           - '1'
         os:

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6068366..e9e6a87 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,14 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 *Note*: We try to adhere to these practices as of version [v0.2.1].
 
-## Version [2.0.0] - 2024-11-26
+## Version [2.0.0] - 2024-12-03
 
+### Changed
+
+- Largely removed Unicode characters from code base. [#134]
+- Removed legacy v1.9 from CI testing. [#134]
 
 ### Added
 
-- added support to MLJ [#126] [#134]
-
+- Added general support for MLJ [#126] [#134]
 
 ## Version [1.1.1] - 2024-09-12
 
 ### Changed

diff --git a/Project.toml b/Project.toml
index 0124164..55a27f5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -32,7 +32,7 @@ MLJBase = "1"
 MLJModelInterface = "1.8.0"
 MLUtils = "0.4"
 Optimisers = "0.2, 0.3"
-Random = "1.9, 1.10"
+Random = "1"
 Statistics = "1"
 Tables = "1.10.1"
 Test = "1"

diff --git a/src/baselaplace/core_struct.jl b/src/baselaplace/core_struct.jl
index 41599bd..c180b88 100644
--- a/src/baselaplace/core_struct.jl
+++ b/src/baselaplace/core_struct.jl
@@ -26,10 +26,10 @@ Container for the parameters of a Laplace approximation.
 - `hessian_structure::HessianStructure`: the structure of the Hessian. Possible values are `:full` and `:kron` or a concrete subtype of `HessianStructure`.
 - `backend::Symbol`: the backend to use. Possible values are `:GGN` and `:Fisher`.
 - `curvature::Union{Curvature.CurvatureInterface,Nothing}`: the curvature interface. Possible values are `nothing` or a concrete subtype of `CurvatureInterface`.
-- `σ::Real`: the observation noise
-- `μ₀::Real`: the prior mean
-- `λ::Real`: the prior precision
-- `P₀::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix
+- `observational_noise::Real`: the observation noise
+- `prior_mean::Real`: the prior mean of the network parameters.
+- `prior_precision::Real`: the prior precision for the network parameters.
+- `prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix for the network parameters.
 """
 Base.@kwdef struct LaplaceParams
     subset_of_weights::Symbol = :all
@@ -37,10 +37,26 @@ Base.@kwdef struct LaplaceParams
     hessian_structure::Union{HessianStructure,Symbol,String} = FullHessian()
     backend::Symbol = :GGN
     curvature::Union{Curvature.CurvatureInterface,Nothing} = nothing
-    σ::Real = 1.0
-    μ₀::Real = 0.0
-    λ::Real = 1.0
-    P₀::Union{Nothing,AbstractMatrix,UniformScaling} = nothing
+    observational_noise::Real = 1.0
+    prior_mean::Real = 0.0
+    prior_precision::Real = 1.0
+    prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling} = nothing
+end
+
+function Base.getproperty(ce::LaplaceParams, sym::Symbol)
+    sym = sym === :σ ? :observational_noise : sym
+    sym = sym === :μ₀ ? :prior_mean : sym
+    sym = sym === :λ ? :prior_precision : sym
+    sym = sym === :P₀ ? :prior_precision_matrix : sym
+    return Base.getfield(ce, sym)
+end
+
+function Base.setproperty!(ce::LaplaceParams, sym::Symbol, val)
+    sym = sym === :σ ? :observational_noise : sym
+    sym = sym === :μ₀ ? :prior_mean : sym
+    sym = sym === :λ ? :prior_precision : sym
+    sym = sym === :P₀ ? :prior_precision_matrix : sym
+    return Base.setfield!(ce, sym, val)
 end
 
 include("estimation_params.jl")
@@ -96,7 +112,7 @@ la = Laplace(nn, likelihood=:regression)
 """
 function Laplace(model::Any; likelihood::Symbol, kwargs...)
     args = LaplaceParams(; kwargs...)
-    @assert !(args.σ != 1.0 && likelihood != :regression) "Observation noise σ ≠ 1 only available for regression."
+    @assert !(args.observational_noise != 1.0 && likelihood != :regression) "Observation noise σ ≠ 1 only available for regression."
# Unpack arguments and wrap in containers: est_args = EstimationParams(args, model, likelihood) diff --git a/src/baselaplace/optimize_prior.jl b/src/baselaplace/optimize_prior.jl index 08f3c49..dd253a7 100644 --- a/src/baselaplace/optimize_prior.jl +++ b/src/baselaplace/optimize_prior.jl @@ -19,8 +19,8 @@ function optimize_prior!( ) # Setup: - logP₀ = isnothing(λinit) ? log.(unique(diag(la.prior.P₀))) : log.([λinit]) # prior precision (scalar) - logσ = isnothing(σinit) ? log.([la.prior.σ]) : log.([σinit]) # noise (scalar) + logP₀ = isnothing(λinit) ? log.(unique(diag(la.prior.prior_precision_matrix))) : log.([λinit]) # prior precision (scalar) + logσ = isnothing(σinit) ? log.([la.prior.observational_noise]) : log.([σinit]) # noise (scalar) opt = Adam(lr) show_every = round(n_steps / 10) i = 0 diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index feb457b..62e82fd 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -135,7 +135,7 @@ function predict( if la.likelihood == :regression # Add observational noise: - pred_var = fvar .+ la.prior.σ^2 + pred_var = fvar .+ la.prior.observational_noise^2 fstd = sqrt.(pred_var) pred_dist = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)] diff --git a/src/baselaplace/prior.jl b/src/baselaplace/prior.jl index accec86..dbb844e 100644 --- a/src/baselaplace/prior.jl +++ b/src/baselaplace/prior.jl @@ -5,16 +5,32 @@ Container for the prior parameters of a Laplace approximation. # Fields -- `σ::Real`: the observation noise -- `μ₀::Real`: the prior mean -- `λ::Real`: the prior precision -- `P₀::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix +- `observational_noise::Real`: the observation noise +- `prior_mean::Real`: the prior mean +- `prior_precision::Real`: the prior precision +- `prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix """ mutable struct Prior - σ::Real - μ₀::Real - λ::Real - P₀::Union{Nothing,AbstractMatrix,UniformScaling} + observational_noise::Real + prior_mean::Real + prior_precision::Real + prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling} +end + +function Base.getproperty(ce::Prior, sym::Symbol) + sym = sym === :σ ? :observational_noise : sym + sym = sym === :μ₀ ? :prior_mean : sym + sym = sym === :λ ? :prior_precision : sym + sym = sym === :P₀ ? :prior_precision_matrix : sym + return Base.getfield(ce, sym) +end + +function Base.setproperty!(ce::Prior, sym::Symbol, val) + sym = sym === :σ ? :observational_noise : sym + sym = sym === :μ₀ ? :prior_mean : sym + sym = sym === :λ ? :prior_precision : sym + sym = sym === :P₀ ? :prior_precision_matrix : sym + return Base.setfield!(ce, sym, val) end """ @@ -23,16 +39,16 @@ end Extracts the prior parameters from a `LaplaceParams` object. 
""" function Prior(params::LaplaceParams, model::Any, likelihood::Symbol) - P₀ = params.P₀ + P₀ = params.prior_precision_matrix n = LaplaceRedux.n_params(model, EstimationParams(params, model, likelihood)) if typeof(P₀) <: UniformScaling P₀ = P₀(n) elseif isnothing(P₀) - P₀ = UniformScaling(params.λ)(n) + P₀ = UniformScaling(params.prior_precision)(n) end # Sanity: if isa(P₀, AbstractMatrix) @assert all(size(P₀) .== n) "Dimensions of prior Hessian $(size(P₀)) do not align with number of parameters ($n)" end - return Prior(params.σ, params.μ₀, params.λ, P₀) + return Prior(params.observational_noise, params.prior_mean, params.prior_precision, P₀) end diff --git a/src/baselaplace/utils.jl b/src/baselaplace/utils.jl index 853a3e0..7c486cc 100644 --- a/src/baselaplace/utils.jl +++ b/src/baselaplace/utils.jl @@ -18,7 +18,7 @@ LaplaceRedux.n_params(la::Laplace) = LaplaceRedux.n_params(la.model, la.est_para Helper function to extract the prior mean of the parameters from a Laplace approximation. """ function get_prior_mean(la::Laplace) - return la.prior.μ₀ + return la.prior.prior_mean end """ @@ -27,7 +27,7 @@ end Helper function to extract the prior precision matrix from a Laplace approximation. """ function prior_precision(la::Laplace) - return la.prior.P₀ + return la.prior.prior_precision_matrix end """ @@ -39,7 +39,7 @@ on the last layer of the NN, of a `Flux.Chain` with Laplace approximation. outdim(la::AbstractLaplace) = outdim(la.model) @doc raw""" - posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀) + posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix) Computes the posterior precision ``P`` for a fitted Laplace Approximation as follows, @@ -47,7 +47,7 @@ Computes the posterior precision ``P`` for a fitted Laplace Approximation as fol where ``\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H`` is the Hessian and ``\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0`` is the prior precision and ``\hat\theta`` is the MAP estimate. """ -function posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀) +function posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix) @assert !isnothing(H) "Hessian not available. Either no value supplied or Laplace Approximation has not yet been estimated." return H + P₀ end @@ -70,7 +70,7 @@ end function log_likelihood(la::AbstractLaplace) factor = -_H_factor(la) if la.likelihood == :regression - c = la.posterior.n_data * la.posterior.n_out * log(la.prior.σ * sqrt(2 * pi)) + c = la.posterior.n_data * la.posterior.n_out * log(la.prior.observational_noise * sqrt(2 * pi)) else c = 0 end @@ -82,7 +82,7 @@ end Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I) """ -_H_factor(la::AbstractLaplace) = 1 / (la.prior.σ^2) +_H_factor(la::AbstractLaplace) = 1 / (la.prior.observational_noise^2) """ _init_H(la::AbstractLaplace) @@ -120,14 +120,14 @@ function log_marginal_likelihood( # update prior precision: if !isnothing(P₀) - la.prior.P₀ = + la.prior.prior_precision_matrix = typeof(P₀) <: AbstractFloat ? UniformScaling(P₀)(la.posterior.n_params) : P₀ end # update observation noise: if !isnothing(σ) - @assert (la.likelihood == :regression || la.prior.σ == σ) "Can only change observational noise σ for regression." - la.prior.σ = σ + @assert (la.likelihood == :regression || la.prior.observational_noise == σ) "Can only change observational noise σ for regression." 
+ la.prior.observational_noise = σ end return log_likelihood(la) - 0.5 * (log_det_ratio(la) + _weight_penalty(la)) @@ -147,7 +147,7 @@ end """ -log_det_prior_precision(la::AbstractLaplace) = sum(log.(diag(la.prior.P₀))) +log_det_prior_precision(la::AbstractLaplace) = sum(log.(diag(la.prior.prior_precision_matrix))) """ log_det_posterior_precision(la::AbstractLaplace) From 04c5a99caccd904e3197819eeedc43b2f0d42c8f Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Dec 2024 09:33:35 +0100 Subject: [PATCH 05/10] only some unicode left in user-facing function --- src/baselaplace/prior.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/baselaplace/prior.jl b/src/baselaplace/prior.jl index dbb844e..34e0eb5 100644 --- a/src/baselaplace/prior.jl +++ b/src/baselaplace/prior.jl @@ -39,16 +39,16 @@ end Extracts the prior parameters from a `LaplaceParams` object. """ function Prior(params::LaplaceParams, model::Any, likelihood::Symbol) - P₀ = params.prior_precision_matrix + prior_precision_matrix = params.prior_precision_matrix n = LaplaceRedux.n_params(model, EstimationParams(params, model, likelihood)) - if typeof(P₀) <: UniformScaling - P₀ = P₀(n) - elseif isnothing(P₀) - P₀ = UniformScaling(params.prior_precision)(n) + if typeof(prior_precision_matrix) <: UniformScaling + prior_precision_matrix = prior_precision_matrix(n) + elseif isnothing(prior_precision_matrix) + prior_precision_matrix = UniformScaling(params.prior_precision)(n) end # Sanity: - if isa(P₀, AbstractMatrix) - @assert all(size(P₀) .== n) "Dimensions of prior Hessian $(size(P₀)) do not align with number of parameters ($n)" + if isa(prior_precision_matrix, AbstractMatrix) + @assert all(size(prior_precision_matrix) .== n) "Dimensions of prior Hessian $(size(prior_precision_matrix)) do not align with number of parameters ($n)" end - return Prior(params.observational_noise, params.prior_mean, params.prior_precision, P₀) + return Prior(params.observational_noise, params.prior_mean, params.prior_precision, prior_precision_matrix) end From 22a436a684d52dd58b9928eb9a3ee4c78c209775 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Dec 2024 09:49:28 +0100 Subject: [PATCH 06/10] let's see --- src/baselaplace/posterior.jl | 20 ++++++++++++++++---- src/baselaplace/utils.jl | 2 +- src/direct_mlj.jl | 4 ++-- src/full.jl | 2 +- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/baselaplace/posterior.jl b/src/baselaplace/posterior.jl index ea0e4c5..c8d1dde 100644 --- a/src/baselaplace/posterior.jl +++ b/src/baselaplace/posterior.jl @@ -5,26 +5,38 @@ Container for the results of a Laplace approximation. 
# Fields -- `μ::AbstractVector`: the MAP estimate of the parameters +- `posterior_mean::AbstractVector`: the MAP estimate of the parameters - `H::Union{AbstractArray,AbstractDecomposition,Nothing}`: the Hessian matrix - `P::Union{AbstractArray,AbstractDecomposition,Nothing}`: the posterior precision matrix -- `Σ::Union{AbstractArray,Nothing}`: the posterior covariance matrix +- `posterior_covariance_matrix::Union{AbstractArray,Nothing}`: the posterior covariance matrix - `n_data::Union{Int,Nothing}`: the number of data points - `n_params::Union{Int,Nothing}`: the number of parameters - `n_out::Union{Int,Nothing}`: the number of outputs - `loss::Real`: the loss value """ mutable struct Posterior - μ::AbstractVector + posterior_mean::AbstractVector H::Union{AbstractArray,AbstractDecomposition,Nothing} P::Union{AbstractArray,AbstractDecomposition,Nothing} - Σ::Union{AbstractArray,Nothing} + posterior_covariance_matrix::Union{AbstractArray,Nothing} n_data::Union{Int,Nothing} n_params::Union{Int,Nothing} n_out::Union{Int,Nothing} loss::Real end +function Base.getproperty(ce::Posterior, sym::Symbol) + sym = sym === :μ ? :posterior_mean : sym + sym = sym === :Σ ? :posterior_covariance_matrix : sym + return Base.getfield(ce, sym) +end + +function Base.setproperty!(ce::Posterior, sym::Symbol, val) + sym = sym === :μ ? :posterior_mean : sym + sym = sym === :Σ ? :posterior_covariance_matrix : sym + return Base.setfield!(ce, sym, val) +end + """ Posterior(model::Any, est_params::EstimationParams) diff --git a/src/baselaplace/utils.jl b/src/baselaplace/utils.jl index 7c486cc..e473c3f 100644 --- a/src/baselaplace/utils.jl +++ b/src/baselaplace/utils.jl @@ -100,7 +100,7 @@ Smaller weights in a neural network can result in a model that is more stable an making a prediction on new data. 
""" function _weight_penalty(la::AbstractLaplace) - μ = la.posterior.μ + μ = la.posterior.posterior_mean μ₀ = get_prior_mean(la) Δ = μ .- μ₀ P₀ = prior_precision(la) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index 9996efc..cec2459 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -484,10 +484,10 @@ function MMI.fitted_params(model::LaplaceModels, fitresult) la, decode = fitresult posterior = la.posterior return ( - mean=posterior.μ, + mean=posterior.posterior_mean, H=posterior.H, P=posterior.P, - cov_matrix=posterior.Σ, + cov_matrix=posterior.posterior_covariance_matrix, n_data=posterior.n_data, n_params=posterior.n_params, n_out=posterior.n_out, diff --git a/src/full.jl b/src/full.jl index 4720150..6cab8ec 100644 --- a/src/full.jl +++ b/src/full.jl @@ -46,7 +46,7 @@ function _fit!( la.posterior.H = H la.posterior.loss = loss la.posterior.P = posterior_precision(la) - la.posterior.Σ = posterior_covariance(la, la.posterior.P) + la.posterior.posterior_covariance_matrix = posterior_covariance(la, la.posterior.P) return la.posterior.n_data = n_data end From ec9785fc9879add5a032853267317a079f4e8046 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Dec 2024 10:11:00 +0100 Subject: [PATCH 07/10] that should do it --- src/direct_mlj.jl | 18 +++++++++--------- test/laplace.jl | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/direct_mlj.jl b/src/direct_mlj.jl index cec2459..e8d1014 100644 --- a/src/direct_mlj.jl +++ b/src/direct_mlj.jl @@ -211,9 +211,9 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.observational_noise, - μ₀=m.prior_mean, - P₀=m.prior_precision_matrix, + observational_noise=m.observational_noise, + prior_mean=m.prior_mean, + prior_precision_matrix=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier @@ -300,9 +300,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.observational_noise, - μ₀=m.prior_mean, - P₀=m.prior_precision_matrix, + observational_noise=m.observational_noise, + prior_mean=m.prior_mean, + prior_precision_matrix=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier la.likelihood = :classification @@ -347,9 +347,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y) subnetwork_indices=m.subnetwork_indices, hessian_structure=m.hessian_structure, backend=m.backend, - σ=m.observational_noise, - μ₀=m.prior_mean, - P₀=m.prior_precision_matrix, + observational_noise=m.observational_noise, + prior_mean=m.prior_mean, + prior_precision_matrix=m.prior_precision_matrix, ) if typeof(m) == LaplaceClassifier la.likelihood = :classification diff --git a/test/laplace.jl b/test/laplace.jl index ea4cfbb..8994103 100644 --- a/test/laplace.jl +++ b/test/laplace.jl @@ -334,7 +334,7 @@ function run_workflow( la = Laplace( nn; likelihood=likelihood, - λ=λ, + prior_precision=λ, subset_of_weights=subset_of_weights, backend=backend, subnetwork_indices=subnetwork_indices, From 55432673f6afb93cc93a531b3b3b54e9b680cc8a Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Dec 2024 10:32:18 +0100 Subject: [PATCH 08/10] let's go --- src/baselaplace/core_struct.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/baselaplace/core_struct.jl b/src/baselaplace/core_struct.jl index c180b88..a6cfd0d 100644 --- a/src/baselaplace/core_struct.jl +++ 
b/src/baselaplace/core_struct.jl
@@ -27,9 +27,13 @@ Container for the parameters of a Laplace approximation.
 - `backend::Symbol`: the backend to use. Possible values are `:GGN` and `:Fisher`.
 - `curvature::Union{Curvature.CurvatureInterface,Nothing}`: the curvature interface. Possible values are `nothing` or a concrete subtype of `CurvatureInterface`.
 - `observational_noise::Real`: the observation noise
+- `σ::Real`: alias for `observational_noise`.
 - `prior_mean::Real`: the prior mean of the network parameters.
+- `μ₀::Real`: alias for `prior_mean`.
 - `prior_precision::Real`: the prior precision for the network parameters.
+- `λ::Real`: alias for `prior_precision`.
 - `prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix for the network parameters.
+- `P₀::Union{Nothing,AbstractMatrix,UniformScaling}`: alias for `prior_precision_matrix`.
 """
 Base.@kwdef struct LaplaceParams
     subset_of_weights::Symbol = :all
@@ -38,9 +42,13 @@ Base.@kwdef struct LaplaceParams
     backend::Symbol = :GGN
     curvature::Union{Curvature.CurvatureInterface,Nothing} = nothing
     observational_noise::Real = 1.0
+    σ::Real = observational_noise
     prior_mean::Real = 0.0
+    μ₀::Real = prior_mean
     prior_precision::Real = 1.0
+    λ::Real = prior_precision
     prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling} = nothing
+    P₀::Union{Nothing,AbstractMatrix,UniformScaling} = prior_precision_matrix
 end
 
 function Base.getproperty(ce::LaplaceParams, sym::Symbol)

From 36c8127b472b6280fbc0fc70637816a5440f9038 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 3 Dec 2024 10:59:40 +0100
Subject: [PATCH 09/10] bumped version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 55a27f5..c9b8e81 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LaplaceRedux"
 uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
 authors = ["Patrick Altmeyer"]
-version = "1.1.1"
+version = "2.0.0"
 
 [deps]
 CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"

From acca82317acb246ecf6aa5bb1bb0dfd6d24817c7 Mon Sep 17 00:00:00 2001
From: pat-alt
Date: Tue, 3 Dec 2024 11:13:47 +0100
Subject: [PATCH 10/10] uh

---
 CHANGELOG.md | 2 +-
 Project.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9e6a87..a0b9851 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 
 *Note*: We try to adhere to these practices as of version [v0.2.1].
 
-## Version [2.0.0] - 2024-12-03
+## Version [1.2.0] - 2024-12-03
 
 ### Changed

diff --git a/Project.toml b/Project.toml
index c9b8e81..609cd6a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LaplaceRedux"
 uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
 authors = ["Patrick Altmeyer"]
-version = "2.0.0"
+version = "1.2.0"
 
 [deps]
 CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
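
Usage sketch (not part of the patches above): the series renames the Unicode hyperparameters (`σ`, `μ₀`, `λ`, `P₀`) to ASCII names (`observational_noise`, `prior_mean`, `prior_precision`, `prior_precision_matrix`) while keeping the old symbols working through the `Base.getproperty`/`Base.setproperty!` overloads added in PATCH 01, PATCH 04 and PATCH 06. The snippet below illustrates how the new names and their aliases are expected to interact once the full series is applied; the network architecture and the specific hyperparameter values are illustrative only.

```julia
using Flux, LaplaceRedux

# Toy network; the layer sizes are arbitrary.
nn = Chain(Dense(2, 10, relu), Dense(10, 1))

# Construct with the new ASCII keyword names (forwarded to `LaplaceParams`):
la = Laplace(nn; likelihood=:regression, prior_precision=0.5, observational_noise=2.0)

# The old Unicode names still read and write the same underlying fields,
# via the `getproperty`/`setproperty!` overloads on `Prior`:
@assert la.prior.σ == la.prior.observational_noise
la.prior.λ = 1.0  # same field as `la.prior.prior_precision`

# The MLJ models expose the same aliases (cf. the test added in PATCH 03):
model = LaplaceRegressor(; model=nn, epochs=20)
model.σ = 2.0  # same field as `model.observational_noise`
@assert model.observational_noise == 2.0
```

For `Prior`, `Posterior` and the MLJ models, the aliases are pure property redirections onto a single underlying field, so reads and writes through either name stay consistent; PATCH 08 additionally adds the Unicode names as keyword fields on `LaplaceParams` so that they can also be passed at construction time.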