Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

88 fix docstrings #89

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

*Note*: We try to adhere to these practices as of version [v0.2.1].

## Version [0.2.1] - 2024-05-29

### Changed

- Improved the docstring for the `predict` and `glm_predictive_distribution` methods. [#88]

### Added

- Added `probit` helper function to compute probit approximation for classification. [#88]
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LaplaceRedux"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
authors = ["Patrick Altmeyer"]
version = "0.2.0"
version = "0.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
59 changes: 44 additions & 15 deletions src/baselaplace/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,32 @@ end
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)

Computes the linearized GLM predictive.

# Arguments

- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.

# Returns

- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.

# Examples

```julia-repl
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
"""
function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X)
fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
return fμ, fvar
Expand All @@ -24,14 +47,27 @@ end

Computes predictions from Bayesian neural network.

# Arguments

- `la::AbstractLaplace`: A Laplace object.
- `X::AbstractArray`: Input data.
- `link_approx::Symbol=:probit`: Link function approximation. Options are `:probit` and `:plugin`.
- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks.

# Returns

- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.

# Examples

```julia-repl
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn)
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
predict(la, hcat(x...))
```
Expand All @@ -51,8 +87,7 @@ function predict(

# Probit approximation
if link_approx == :probit
κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar)
z = κ .* fμ
z = probit(fμ, fvar)
end

if link_approx == :plugin
Expand All @@ -75,20 +110,14 @@ function predict(
end

"""
predict(la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true)

Compute predictive posteriors for a batch of inputs.
probit(fμ::AbstractArray, fvar::AbstractArray)

Predicts on a matrix of inputs. Note, input is assumed to be batched only if it is a matrix.
If the input dimensionality of the model is 1 (a vector), one should still prepare a 1×B matrix batch as input.
Compute the probit approximation of the predictive distribution.
"""
function predict(
la::AbstractLaplace, X::Matrix; link_approx=:probit, predict_proba::Bool=true
)
return stack([
predict(la, X[:, i]; link_approx=link_approx, predict_proba=predict_proba) for
i in 1:size(X, 2)
])
function probit(fμ::AbstractArray, fvar::AbstractArray)
κ = 1 ./ sqrt.(1 .+ π / 8 .* fvar)
z = κ .* fμ
return z
end

"""
Expand Down
Loading