diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..916da018 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +*Note*: We try to adhere to these practices as of version [v0.2.1]. + +## Version [0.2.1] - 2024-05-29 + +### Changed + +- Improved the docstring for the `predict` and `glm_predictive_distribution` methods. [#88] + +### Added + +- Added `probit` helper function to compute probit approximation for classification. [#88] \ No newline at end of file diff --git a/Project.toml b/Project.toml index 2ad5609a..24355ed3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "0.2.0" +version = "0.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 19d67b6c..49773a23 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -11,9 +11,33 @@ end glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) Computes the linearized GLM predictive. + +# Arguments + +- `la::AbstractLaplace`: A Laplace object. +- `X::AbstractArray`: Input data. + +# Returns + +- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux. +- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux. 
+ +# Examples + +```julia-repl +using Flux, LaplaceRedux +using LaplaceRedux.Data: toy_data_linear +x, y = toy_data_linear() +data = zip(x,y) +nn = Chain(Dense(2,1)) +la = Laplace(nn; likelihood=:classification) +fit!(la, data) +glm_predictive_distribution(la, hcat(x...)) +``` """ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) 𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X) + fμ = reshape(fμ, Flux.outputsize(la.model, size(X))) fvar = functional_variance(la, 𝐉) fvar = reshape(fvar, size(fμ)...) return fμ, fvar @@ -24,14 +48,27 @@ end Computes predictions from Bayesian neural network. +# Arguments + +- `la::AbstractLaplace`: A Laplace object. +- `X::AbstractArray`: Input data. +- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`. +- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks. + +# Returns + +- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux. +- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux. + # Examples ```julia-repl using Flux, LaplaceRedux +using LaplaceRedux.Data: toy_data_linear x, y = toy_data_linear() data = zip(x,y) nn = Chain(Dense(2,1)) -la = Laplace(nn) +la = Laplace(nn; likelihood=:classification) fit!(la, data) predict(la, hcat(x...)) ``` @@ -51,8 +88,7 @@ function predict( # Probit approximation if link_approx == :probit - κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar) - z = κ .* fμ + z = probit(fμ, fvar) end if link_approx == :plugin @@ -75,20 +111,14 @@ function predict( end """ - predict(la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true) - -Compute predictive posteriors for a batch of inputs. + probit(fμ::AbstractArray, fvar::AbstractArray) -Predicts on a matrix of inputs. 
Note, input is assumed to be batched only if it is a matrix. -If the input dimensionality of the model is 1 (a vector), one should still prepare a 1×B matrix batch as input. +Compute the probit approximation of the predictive distribution: scales the predictive mean `fμ` element-wise by `κ = 1 / sqrt(1 + π/8 * fvar)` and returns the result. """ -function predict( - la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true -) - return stack([ - predict(la, X[:, i]; link_approx=link_approx, predict_proba=predict_proba) for - i in 1:size(X, 2) - ]) +function probit(fμ::AbstractArray, fvar::AbstractArray) + κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar) + z = κ .* fμ + return z end """