Commit
fit and computation of scores seem to work now
pasq-cat committed Jul 29, 2024
1 parent 49d6857 commit c877b5f
Showing 2 changed files with 16 additions and 11 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InferOpt = "4846b161-c94e-4150-8dac-c7ae193c601f"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
26 changes: 15 additions & 11 deletions src/conformal_models/inductive_bayes_regression.jl
@@ -1,4 +1,4 @@
# Simple
#using LaplaceRedux.LaplaceRegression
"The `BayesRegressor` is the simplest approach to Inductive Conformalized Bayes."
mutable struct BayesRegressor{Model <: Supervised} <: ConformalInterval
    model::Model
@@ -9,19 +9,16 @@
end

function ConformalBayes(y, fμ, fvar)
    # Ensure σ is positive
    if fvar <= 0
        throw(ArgumentError("variance must be positive"))
    end
    std = sqrt.(fvar)
    # Compute the probability density
    coeff = 1 ./ (std .* sqrt(2 * π))
    exponent = -((y .- fμ).^2) ./ (2 .* std.^2)
    return -coeff .* exp.(exponent)
end
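
As a quick sanity check on the heuristic (an illustrative sketch, not part of the commit): the score is the negative Gaussian density of y under N(fμ, fvar), so it is lowest, i.e. most conformal, when y equals the predictive mean, and it rises toward zero as y moves away from it.

    # Illustrative only: lower scores mean more conformal observations.
    ConformalBayes(0.0, 0.0, 1.0)   # ≈ -0.3989, density peak at y == fμ
    ConformalBayes(2.0, 0.0, 1.0)   # ≈ -0.0540, two σ away, less conformal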

function BayesRegressor(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=ConformalBayes(y, fμ, fvar), train_ratio::AbstractFloat=0.5)
    #@assert typeof(model) == :Laplace "Model must be of type Laplace"
function BayesRegressor(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=ConformalBayes, train_ratio::AbstractFloat=0.5)
    #@assert typeof(model.model) == :Laplace "Model must be of type Laplace"
    #@assert typeof(model) == LaplaceRegression "Model must be of type Laplace"
    return BayesRegressor(model, coverage, nothing, heuristic, train_ratio)
end
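
For orientation, a hypothetical construction call; `LaplaceRegression` and its zero-argument constructor are assumptions here, since the exact LaplaceRedux API is not shown in this diff:

    # Hypothetical usage sketch; the LaplaceRedux constructor may differ.
    using LaplaceRedux: LaplaceRegression
    laplace = LaplaceRegression()
    conf_model = BayesRegressor(laplace; coverage=0.9, train_ratio=0.5)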

@@ -37,15 +34,22 @@
A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``. The simple approach only takes the softmax probability of the true label into account.
"""
function MMI.fit(conf_model::BayesRegressor, verbosity, X, y)

    # Data Splitting:
    Xtrain, ytrain, Xcal, ycal = split_data(conf_model, X, y)

    # Training:
    fitresult, cache, report = MMI.fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, Xcal)...)
    fitresult, cache, report = MMI.fit(
        conf_model.model, verbosity, MMI.reformat(conf_model.model, Xtrain, ytrain)...)

    # Nonconformity Scores:
    fμ, fvar = MMI.predict(conf_model.model, fitresult, Xcal)
    yhat = MMI.predict(conf_model.model, fitresult, Xcal)
    fμ = vcat([x[1] for x in yhat]...)
    fvar = vcat([x[2] for x in yhat]...)

    conf_model.scores = @.(conf_model.heuristic(ycal, fμ, fvar))

    return (fitresult, cache, report)
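
Because the heuristic takes scalars, the `@.` macro is what maps it across the calibration set: `scores[i]` is the negative predictive density of `ycal[i]` under the i-th predictive mean and variance. A toy sketch with made-up numbers, using the default `ConformalBayes` heuristic:

    # Toy values, illustrative only.
    ycal = [1.0, 2.0]; fμ = [0.9, 2.5]; fvar = [0.1, 0.2]
    scores = @.(ConformalBayes(ycal, fμ, fvar))   # elementwise, length 2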
@@ -67,5 +71,5 @@
    v = conf_model.scores
    q̂ = qplus(v, conf_model.coverage)
    #normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
    return
    return fμ, fvar
end
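
Note that `q̂` is computed but not yet used in the returned value. For a density-based score such as `ConformalBayes`, the conformal set is {y : -pdf(y) ≤ q̂}, which for a Gaussian predictive inverts to a closed-form interval. A sketch of that inversion (illustrative, not part of the commit), assuming q̂ < 0 and -q̂ * σ * sqrt(2π) ≤ 1 so the logarithm is defined:

    # Illustrative inversion of the density-based score into an interval.
    σ = sqrt.(fvar)
    δ = σ .* sqrt.(-2 .* log.(-q̂ .* σ .* sqrt(2π)))
    lower, upper = fμ .- δ, fμ .+ δ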
