Skip to content

Commit

Permalink
this should work
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed May 29, 2024
1 parent cac236d commit 00c6932
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/baselaplace/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ Computes the linearized GLM predictive.
# Returns
- `fμ::AbstractArray`: Mean of the predictive distribution. The format is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The format is column-major as in Flux.
- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.
# Examples
Expand All @@ -36,7 +36,7 @@ glm_predictive_distribution(la, hcat(x...))
"""
function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X)
= permutedims(fμ)
= reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
return fμ, fvar
Expand All @@ -56,7 +56,8 @@ Computes predictions from Bayesian neural network.
# Returns
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The format is column-major as in Flux.
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.
# Examples
Expand Down

0 comments on commit 00c6932

Please sign in to comment.