Merge pull request #103 from JuliaTrustworthyAI/101-remaining-issue-on-the-mljinterface

101 remaining issue on the mljinterface
pat-alt authored Jul 22, 2024
2 parents 7d70d2c + 240b192 commit 4509807
Showing 5 changed files with 46 additions and 27 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].

+## Version [1.0.1] - 2024-07-19
+
+### Changed
+- Added the option for `predict` to return the mean and the variance in the case of regression [[#101](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/101)]
+- Modified `mlj_flux.jl` by adding the `ret_distr` parameter and fixed `MLJFlux.predict` for both classification and regression tasks.

## Version [1.0.0] - 2024-07-17

### Changed
6 changes: 5 additions & 1 deletion src/baselaplace/predicting.jl
@@ -115,7 +115,11 @@ function predict(

    # Regression:
    if la.likelihood == :regression
-       return reshape(normal_distr, (:, 1))
+       if ret_distr
+           return reshape(normal_distr, (:, 1))
+       else
+           return fμ, fvar
+       end
    end

    # Classification:
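For context, the change above gives `predict` two return modes for regression. Below is a minimal sketch of both; the data layout, the network, and the `(x, y)`-pair format for `fit!` are assumptions patterned on the package tests, not part of this diff:

```julia
using Flux, LaplaceRedux
using Distributions: mean

# Toy regression setup (placeholder; mirrors the shapes used in test/laplace.jl).
X = randn(Float32, 3, 50)                            # 3 features, 50 observations
y = vec(sum(X; dims=1)) + 0.1f0 * randn(Float32, 50)
data = zip(eachcol(X), y)                            # assumed (x, y)-pair format for fit!
nn = Chain(Dense(3, 8, relu), Dense(8, 1))

la = Laplace(nn; likelihood=:regression)
fit!(la, data)

fμ, fvar = predict(la, X)               # default: predictive mean and variance
dists = predict(la, X; ret_distr=true)  # column matrix of Normal distribution objects
μ = mean.(dists)                        # the means can be recovered from the Normals
```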
46 changes: 30 additions & 16 deletions src/mlj_flux.jl
@@ -15,7 +15,8 @@ using Optimisers: Optimisers
A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type.
It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-The model is trained using the `fit!` method. The model is defined by the following default parameters:
+The model is defined by the following default parameters for all `MLJFlux` models:
- `builder`: a Flux model that constructs the neural network.
- `optimiser`: a Flux optimiser.
@@ -27,13 +28,17 @@ The model is trained using the `fit!` method. The model is defined by the follow
- `rng`: a random number generator.
- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
- `acceleration`: the computational resource to use.
+The model also has the following parameters, which are specific to the Laplace approximation:
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
- `σ`: the standard deviation of the prior distribution.
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
+- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the mean and the variance (`false`).
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic
@@ -55,16 +60,17 @@ MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilis
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+   ret_distr::Bool = false::(_ in (true, false))
    fit_prior_nsteps::Int = 100::(_ > 0)
end
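As a usage illustration for the struct above, here is a sketch assembled from pieces that appear in the `test/mlj_flux_interfacing.jl` diff further down; the `epochs` value and data sizes are arbitrary choices:

```julia
using Flux, MLJBase, MLJFlux, LaplaceRedux

N = 100
X = MLJBase.table(rand(Float32, N, 4))
y = 2 * X.x1 - X.x3 + 0.1 * rand(N)

builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
optimiser = Flux.Optimise.Adam(0.03)
model = LaplaceRegression(; builder=builder, optimiser=optimiser, epochs=20)

fitresult, cache, report = MLJBase.fit(model, 0, X, y)
# With ret_distr=false (the default), each entry of yhat holds the predictive
# mean and variance for one observation; with ret_distr=true it would hold a
# Distributions.jl object instead, per the ret_distr docstring above.
yhat = MLJBase.predict(model, fitresult, X)
```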

"""
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic
A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type.
-It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
-The model is trained using the `fit!` method. The model is defined by the following default parameters:
+It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.
+The model is defined by the following default parameters for all `MLJFlux` models:
- `builder`: a Flux model that constructs the neural network.
- `finaliser`: a Flux model that processes the output of the neural network.
- `optimiser`: a Flux optimiser.
@@ -76,13 +82,19 @@ A mutable struct representing a Laplace Classification model that extends the ML
- `rng`: a random number generator.
- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining.
- `acceleration`: the computational resource to use.
+The model also has the following parameters, which are specific to the Laplace approximation:
- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`.
- `subnetwork_indices`: the indices of the subnetworks.
- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`.
- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`.
- `σ`: the standard deviation of the prior distribution.
- `μ₀`: the mean of the prior distribution.
- `P₀`: the covariance matrix of the prior distribution.
- `link_approx`: the link approximation to use, either `:probit` or `:plugin`.
- `predict_proba`: a boolean that selects whether to predict probabilities or not.
+- `ret_distr`: a boolean that tells `predict` whether to return distribution objects from Distributions.jl (`true`) or just the probabilities (`false`).
- `fit_prior_nsteps`: the number of steps used to fit the priors.
"""
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic
@@ -107,6 +119,7 @@ MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbab
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    link_approx::Symbol = :probit::(_ in (:probit, :plugin))
    predict_proba::Bool = true::(_ in (true, false))
+   ret_distr::Bool = false::(_ in (true, false))
    fit_prior_nsteps::Int = 100::(_ > 0)
end
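A corresponding sketch for `LaplaceClassification`; this one is hypothetical throughout, since no classification setup appears in this diff, and the two-class data and the `coerce` step in particular are assumptions:

```julia
using Flux, MLJBase, MLJFlux, LaplaceRedux

N = 100
X = MLJBase.table(rand(Float32, N, 4))
y = coerce(rand(["yes", "no"], N), Multiclass)  # assumed label preparation

builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
model = LaplaceClassification(;
    builder=builder,
    epochs=20,
    link_approx=:probit,  # or :plugin
    predict_proba=true,   # return class probabilities rather than labels
    ret_distr=false,      # set to true for Distributions.jl objects
)

fitresult, _, _ = MLJBase.fit(model, 0, X, y)
yhat = MLJBase.predict(model, fitresult, X)
```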

@@ -273,15 +286,11 @@ Predict the output for new input data using a Laplace regression model.
"""
function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew)
    Xnew = MLJBase.matrix(Xnew)

-   model = fitresult[1]
+   la = fitresult[1]
    # convert into a vector of vectors because MLJ asks us to do so
-   X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)]
-   #inizialize output vector yhat
-   yhat = []
+   X_vec = collect(eachrow(Xnew))
    # Predict using Laplace and collect the predictions
-   yhat = [glm_predictive_distribution(model, x_vec) for x_vec in X_vec]
+   yhat = [map(x -> LaplaceRedux.predict(la, x; ret_distr=model.ret_distr), X_vec)...]
    return yhat
end

@@ -448,13 +457,18 @@ function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew)
    la = fitresult[1]
    Xnew = MLJBase.matrix(Xnew)
    # convert into a vector of vectors because Laplace asks us to do so
-   X_vec = X_vec = [Xnew[i, :] for i in 1:size(Xnew, 1)]
-
-   # Predict using Laplace and collect the predictions
+   X_vec = collect(eachrow(Xnew))
    predictions = [
-       LaplaceRedux.predict(
-           la, x; link_approx=model.link_approx, predict_proba=model.predict_proba
-       ) for x in X_vec
+       map(
+           x -> LaplaceRedux.predict(
+               la,
+               x;
+               link_approx=model.link_approx,
+               predict_proba=model.predict_proba,
+               ret_distr=model.ret_distr,
+           ),
+           X_vec,
+       )...,
    ]

    return predictions
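A brief note on the `[map(f, xs)...]` idiom used in both rewritten `predict` methods: `map` yields one prediction per observation, and the splat collects the results into a plain `Vector` for downstream MLJ code. A tiny self-contained illustration, using a stand-in function rather than the package API:

```julia
# Stand-in for LaplaceRedux.predict: any per-observation function works here.
f(x) = sum(x)

Xnew = rand(5, 3)               # 5 observations, 3 features
X_vec = collect(eachrow(Xnew))  # vector of observation rows, as in the diff
out = [map(f, X_vec)...]        # one entry per observation, splatted into a Vector
@assert length(out) == 5
```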
7 changes: 1 addition & 6 deletions test/laplace.jl
@@ -229,12 +229,7 @@ end
    la = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)
    fit!(la, data)
    matrix_normals = Matrix{Normal{T}} where {T<:AbstractFloat}
-   @test typeof(predict(la, X)) <: matrix_normals
-
-   #predict(la, X[1]; link_approx=:plugin)
-   #predict(la, X[1]; link_approx=:probit)
-   #predict(la, X[1]; ret_distr=true)
-   #predict(la, X[1]; ret_distr=true, predict_proba=false)
+   @test typeof(predict(la, X; ret_distr=true)) <: matrix_normals
end

#testing the function LaplaceRedux.has_softmax_or_sigmoid_final_layer
8 changes: 4 additions & 4 deletions test/mlj_flux_interfacing.jl
@@ -24,6 +24,7 @@ using StableRNGs
        subset_of_weights=:incorrect,
        hessian_structure=:incorrect,
        backend=:incorrect,
+       ret_distr=true,
    )

    fitresult, cache, _report = MLJBase.fit(model, 0, X, y)
@@ -46,6 +47,8 @@ using StableRNGs
    history = _report.training_losses
    @test length(history) == model.epochs + 1

+   yhat = MLJBase.predict(model, fitresult, X)

    # start fresh with small epochs:
    model = LaplaceRegression(;
        builder=builder,
@@ -88,13 +91,10 @@ using StableRNGs
    N = 300
    X = MLJBase.table(rand(Float32, N, 4))
    ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)

    builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu)
    optimiser = Flux.Optimise.Adam(0.03)

-   y = ycont

-   @test basictest_regression(X, y, builder, optimiser, 0.9)
+   @test basictest_regression(X, ycont, builder, optimiser, 0.9)
end

@testset "Classification" begin
