changed score function
pasq-cat committed Jul 29, 2024
1 parent 3584534 commit 49d6857
Showing 2 changed files with 24 additions and 10 deletions.
@@ -12,8 +12,9 @@ Overloads the `score` function for the `MLJFluxModel` type.
function ConformalPrediction.score(
conf_model::BayesRegressor, ::Type{<:MLJFluxModel}, fitresult, X, y
)
-    X = permutedims(matrix(X))
-    ŷ = permutedims(fitresult[1](X))
+    X = matrix(X)
+    fμ, fvar = fitresult[1](X)

scores = @.(conf_model.heuristic(y, ŷ))
return scores
end
29 changes: 21 additions & 8 deletions src/conformal_models/inductive_bayes_regression.jl
@@ -8,8 +8,20 @@
train_ratio::AbstractFloat
end

-function BayesRegressor(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=f(y, ŷ)=-ŷ, train_ratio::AbstractFloat=0.5)
-    @assert typeof(model) == :Laplace "Model must be of type Laplace"
+function ConformalBayes(y, fμ, fvar)
+    # Ensure the predictive variance is positive
+    if fvar <= 0
+        throw(ArgumentError("variance must be positive"))
+    end
+    std = sqrt.(fvar)
+    # Compute the Gaussian probability density and negate it
+    coeff = 1 ./ (std .* sqrt(2 * π))
+    exponent = -((y .- fμ).^2) ./ (2 .* std.^2)
+    return -coeff .* exp.(exponent)
+end

+function BayesRegressor(model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=ConformalBayes, train_ratio::AbstractFloat=0.5)
+    #@assert typeof(model) == :Laplace "Model must be of type Laplace"
return BayesRegressor(model, coverage, nothing, heuristic, train_ratio)
end
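
In formula form, the `ConformalBayes` heuristic added above scores a calibration point by the negative Gaussian density of the observed target under the model's predictive mean and variance (a sketch using our own symbols ``\hat\mu`` for `fμ` and ``\hat\sigma^2`` for `fvar`):

``s(y) = -\frac{1}{\sqrt{2\pi\hat\sigma^2}} \exp\!\left(-\frac{(y - \hat\mu)^2}{2\hat\sigma^2}\right)``

so observations that are more probable under the predictive distribution receive lower (more negative) nonconformity scores.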

@@ -33,8 +45,8 @@
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, MMI.reformat(conf_model.model, Xcal)...)

# Nonconformity Scores:
-    ŷ = pdf.(MMI.predict(conf_model.model, fitresult, Xcal), ycal) # predict returns a vector of distributions
-    conf_model.scores = @.(conf_model.heuristic(ycal, ŷ))
+    fμ, fvar = MMI.predict(conf_model.model, fitresult, Xcal)
+    conf_model.scores = @.(conf_model.heuristic(ycal, fμ, fvar))

return (fitresult, cache, report)
end
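
For orientation, here is a minimal, self-contained sketch (with toy numbers that are not taken from the repository) of the calibration-score computation in the hunk above; `conformal_bayes` simply mirrors the `ConformalBayes` heuristic from this commit.

```julia
# Toy illustration of the calibration step above (values are made up).
# conformal_bayes mirrors ConformalBayes: the negative Gaussian density
# of the observed target under the predictive mean and variance.
function conformal_bayes(y, fμ, fvar)
    fvar > 0 || throw(ArgumentError("variance must be positive"))
    std = sqrt(fvar)
    coeff = 1 / (std * sqrt(2 * π))
    return -coeff * exp(-((y - fμ)^2) / (2 * std^2))
end

ycal = [1.0, 2.0, 3.0]      # calibration targets (toy)
fμ   = [0.9, 2.1, 2.9]      # predictive means (toy)
fvar = [0.04, 0.09, 0.25]   # predictive variances (toy)

# Element-wise scores, as in conf_model.scores = @.(conf_model.heuristic(ycal, fμ, fvar))
scores = @. conformal_bayes(ycal, fμ, fvar)
```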
@@ -51,8 +63,9 @@
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::BayesRegressor, fitresult, Xnew)
-    ŷ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...)
-    v = conf_model.scores
-    q̂ = qplus(v, conf_model.coverage)
-    return
+    fμ, fvar = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...)
+    v = conf_model.scores
+    q̂ = qplus(v, conf_model.coverage)
+    #normal_distr = [Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 2)]
+    return
end
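
A note on the quantile used above: assuming `qplus` implements the usual inductive conformal quantile (our reading, not spelled out in this diff), the threshold is

``\hat{q} = \text{the } \frac{\lceil (n+1)(1-\alpha) \rceil}{n} \text{ empirical quantile of the calibration scores } v,``

where ``n`` is the number of calibration scores and ``1-\alpha`` corresponds to `conf_model.coverage`.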
