134 breaking remove unicode characters #135

Merged · 10 commits · Dec 3, 2024
Changes from 2 commits
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].


## Version [2.0.0] - 2024-11-26



### Added

- Added support for MLJ [#126] [#134]


## Version [1.1.1] - 2024-09-12

### Changed
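For context on the MLJ changelog entry above ([#126], [#134]): a minimal sketch of the workflow the new interface enables. The toy data and the default constructor arguments are illustrative assumptions, not part of this diff; in practice you may need to configure the underlying network.

```julia
using MLJBase
using LaplaceRedux

# Illustrative toy data; any Tables.jl-compatible X with numeric y works.
X, y = MLJBase.make_regression(100, 3)

model = LaplaceRegressor()           # constructor arguments elided for brevity
mach = MLJBase.machine(model, X, y)  # bind model and data
MLJBase.fit!(mach)                   # train the network and the Laplace approximation
yhat = MLJBase.predict(mach, X)      # vector of Normal distributions (see predict below)
```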
2 changes: 1 addition & 1 deletion docs/src/tutorials/regression.qmd
@@ -128,7 +128,7 @@ then we can plot the calibration plot of our neural model

```{julia}
#| output: true
- Calibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)
+ calibration_plot(la,y_test,vec(predicted_distributions);n_bins = 20)
```

and compute the sharpness of the predictive distribution
90 changes: 54 additions & 36 deletions src/direct_mlj.jl
@@ -19,9 +19,9 @@ MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJBase.Probabilistic
hessian_structure::Union{HessianStructure,Symbol,String} =
:full::(_ in (:full, :diagonal))
backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
- σ::Float64 = 1.0
- μ₀::Float64 = 0.0
- P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ observational_noise::Float64 = 1.0
+ prior_mean::Float64 = 0.0
+ prior_precision_matrix::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
fit_prior_nsteps::Int = 100::(_ > 0)
link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end
@@ -37,14 +37,32 @@ MLJBase.@mlj_model mutable struct LaplaceRegressor <: MLJBase.Probabilistic
hessian_structure::Union{HessianStructure,Symbol,String} =
:full::(_ in (:full, :diagonal))
backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
- σ::Float64 = 1.0
- μ₀::Float64 = 0.0
- P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
+ observational_noise::Float64 = 1.0
+ prior_mean::Float64 = 0.0
+ prior_precision_matrix::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
fit_prior_nsteps::Int = 100::(_ > 0)
end

LaplaceModels = Union{LaplaceRegressor,LaplaceClassifier}

+ # Aliases
+ const LM = LaplaceModels
+ function Base.getproperty(ce::LM, sym::Symbol)
+     sym = sym === :σ ? :observational_noise : sym
+     sym = sym === :μ₀ ? :prior_mean : sym
+     sym = sym === :P₀ ? :prior_precision_matrix : sym
+     return Base.getfield(ce, sym)
+ end
+ function Base.setproperty!(ce::LM, sym::Symbol, val)
+     sym = sym === :σ ? :observational_noise : sym
+     sym = sym === :μ₀ ? :prior_mean : sym
+     sym = sym === :P₀ ? :prior_precision_matrix : sym
+     return Base.setfield!(ce, sym, val)
+ end
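These overloads keep the old Unicode names usable as read and write aliases for the new ASCII field names. A quick sketch of the effect, assuming a `LaplaceRegressor` can be built with its defaults:

```julia
model = LaplaceRegressor()

model.σ == model.observational_noise       # true: getproperty redirects the alias
model.μ₀ = 0.5                             # setproperty! writes through the alias
model.prior_mean                           # 0.5
model.P₀ === model.prior_precision_matrix  # true
```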




# for fit:
function MMI.reformat(::LaplaceRegressor, X, y)
return (MLJBase.matrix(X) |> permutedims, (reshape(y, 1, :), nothing))
@@ -193,9 +211,9 @@ function MMI.fit(m::LaplaceModels, verbosity, X, y)
subnetwork_indices=m.subnetwork_indices,
hessian_structure=m.hessian_structure,
backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
+ σ=m.observational_noise,
+ μ₀=m.prior_mean,
+ P₀=m.prior_precision_matrix,
)

if typeof(m) == LaplaceClassifier
@@ -282,9 +300,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
subnetwork_indices=m.subnetwork_indices,
hessian_structure=m.hessian_structure,
backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
+ σ=m.observational_noise,
+ μ₀=m.prior_mean,
+ P₀=m.prior_precision_matrix,
)
if typeof(m) == LaplaceClassifier
la.likelihood = :classification
@@ -315,9 +333,9 @@ function MMI.update(m::LaplaceModels, verbosity, old_fitresult, old_cache, X, y)
:subnetwork_indices,
:hessian_structure,
:backend,
- :σ,
- :μ₀,
- :P₀,
+ :observational_noise,
+ :prior_mean,
+ :prior_precision_matrix,
)
println(" updating only the laplace optimization part")
old_la = old_fitresult[1]
@@ -329,9 +347,9 @@
subnetwork_indices=m.subnetwork_indices,
hessian_structure=m.hessian_structure,
backend=m.backend,
- σ=m.σ,
- μ₀=m.μ₀,
- P₀=m.P₀,
+ σ=m.observational_noise,
+ μ₀=m.prior_mean,
+ P₀=m.prior_precision_matrix,
)
if typeof(m) == LaplaceClassifier
la.likelihood = :classification
@@ -452,10 +470,10 @@ end

# Returns
A named tuple containing:
- - `μ`: The mean of the posterior distribution.
+ - `mean`: The mean of the posterior distribution.
- `H`: The Hessian of the posterior distribution.
- `P`: The precision matrix of the posterior distribution.
- - `Σ`: The covariance matrix of the posterior distribution.
+ - `cov_matrix`: The covariance matrix of the posterior distribution.
- `n_data`: The number of data points.
- `n_params`: The number of parameters.
- `n_out`: The number of outputs.
@@ -466,10 +484,10 @@ function MMI.fitted_params(model::LaplaceModels, fitresult)
la, decode = fitresult
posterior = la.posterior
return (
- μ=posterior.μ,
+ mean=posterior.μ,
H=posterior.H,
P=posterior.P,
- Σ=posterior.Σ,
+ cov_matrix=posterior.Σ,
n_data=posterior.n_data,
n_params=posterior.n_params,
n_out=posterior.n_out,
@@ -517,7 +535,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew)
means, variances = yhat

# Create Normal distributions from the means and variances
- return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)])
+ return vec([Normal(mean, sqrt(variance)) for (mean, variance) in zip(means, variances)])

else
predictions =
@@ -551,7 +569,7 @@ MMI.metadata_pkg(
MLJBase.metadata_model(
LaplaceClassifier;
input_scitype=Union{
- AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
+ AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Infinite}}, # matrix with mixed types
MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
},
target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
@@ -562,8 +580,8 @@ MLJBase.metadata_model(
MLJBase.metadata_model(
LaplaceRegressor;
input_scitype=Union{
- AbstractMatrix{<:Union{MLJBase.Finite,MLJBase.Continuous}}, # matrix with mixed types
- MLJBase.Table(MLJBase.Finite, MLJBase.Continuous), # table with mixed types
+ AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Infinite}}, # matrix with mixed types
+ MLJBase.Table(MLJBase.Finite, MLJBase.Infinite), # table with mixed types
},
target_scitype=AbstractArray{MLJBase.Continuous},
supports_training_losses=true,
@@ -626,11 +644,11 @@ Train the machine using `fit!(mach, rows=...)`.

- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.

- - `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+ - `observational_noise (alias σ)::Float64 = 1.0`: the standard deviation of the observational noise.

- - `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+ - `prior_mean (alias μ₀)::Float64 = 0.0`: the mean of the prior distribution.

- - `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+ - `prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the precision matrix of the prior distribution.

- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
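
To make the renamed hyperparameters above concrete, a hedged construction sketch (the keyword constructor comes from `MLJBase.@mlj_model`; the values shown are just the documented defaults):

```julia
clf = LaplaceClassifier(
    observational_noise = 1.0,         # formerly σ
    prior_mean = 0.0,                  # formerly μ₀
    prior_precision_matrix = nothing,  # formerly P₀
    fit_prior_nsteps = 100,
    link_approx = :probit,
)
```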

@@ -650,13 +668,13 @@ Train the machine using `fit!(mach, rows=...)`.

The fields of `fitted_params(mach)` are:

- - `μ`: The mean of the posterior distribution.
+ - `mean`: The mean of the posterior distribution.

- `H`: The Hessian of the posterior distribution.

- `P`: The precision matrix of the posterior distribution.

- - `Σ`: The covariance matrix of the posterior distribution.
+ - `cov_matrix`: The covariance matrix of the posterior distribution.

- `n_data`: The number of data points.

@@ -766,11 +784,11 @@ Train the machine using `fit!(mach, rows=...)`.

- `backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))`: the backend to use, either `:GGN` or `:EmpiricalFisher`.

- - `σ::Float64 = 1.0`: the standard deviation of the prior distribution.
+ - `observational_noise (alias σ)::Float64 = 1.0`: the standard deviation of the observational noise.

- - `μ₀::Float64 = 0.0`: the mean of the prior distribution.
+ - `prior_mean (alias μ₀)::Float64 = 0.0`: the mean of the prior distribution.

- - `P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the covariance matrix of the prior distribution.
+ - `prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing`: the precision matrix of the prior distribution.

- `fit_prior_nsteps::Int = 100::(_ > 0) `: the number of steps used to fit the priors.
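
The same renames apply on the regressor side; a small sketch showing that the new keyword names and the deprecated Unicode properties address the same fields:

```julia
reg = LaplaceRegressor(observational_noise = 1.0, prior_mean = 0.0)

reg.σ           # 1.0, read back through the property alias
reg.prior_mean  # 0.0
```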

@@ -789,13 +807,13 @@ Train the machine using `fit!(mach, rows=...)`.

The fields of `fitted_params(mach)` are:

- - `μ`: The mean of the posterior distribution.
+ - `mean`: The mean of the posterior distribution.

- `H`: The Hessian of the posterior distribution.

- `P`: The precision matrix of the posterior distribution.

- - `Σ`: The covariance matrix of the posterior distribution.
+ - `cov_matrix`: The covariance matrix of the posterior distribution.

- `n_data`: The number of data points.

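Finally, a hedged sketch of reading the renamed fitted parameters back out, reusing `mach` from the workflow sketch near the top (field names follow the `fitted_params` docstring above):

```julia
fp = MLJBase.fitted_params(mach)

fp.mean        # posterior mean, formerly exposed as μ
fp.cov_matrix  # posterior covariance, formerly exposed as Σ
fp.H, fp.P     # Hessian and precision matrix, unchanged
```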