diff --git a/src/baselaplace/predicting.jl b/src/baselaplace/predicting.jl index 6f6cf21a..49773a23 100644 --- a/src/baselaplace/predicting.jl +++ b/src/baselaplace/predicting.jl @@ -19,8 +19,8 @@ Computes the linearized GLM predictive. # Returns -- `fμ::AbstractArray`: Mean of the predictive distribution. The format is column-major as in Flux. -- `fvar::AbstractArray`: Variance of the predictive distribution. The format is column-major as in Flux. +- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux. +- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux. # Examples @@ -36,7 +36,7 @@ glm_predictive_distribution(la, hcat(x...)) """ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray) 𝐉, fμ = Curvature.jacobians(la.est_params.curvature, X) - fμ = permutedims(fμ) + fμ = reshape(fμ, Flux.outputsize(la.model, size(X))) fvar = functional_variance(la, 𝐉) fvar = reshape(fvar, size(fμ)...) return fμ, fvar @@ -56,7 +56,8 @@ Computes predictions from Bayesian neural network. # Returns -- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The format is column-major as in Flux. +- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux. +- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux. # Examples