Skip to content

Commit

Permalink
trying this another way
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Mar 31, 2023
1 parent ecdb162 commit 56cca1c
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 12 deletions.
3 changes: 0 additions & 3 deletions src/ConformalPrediction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ export ConformalModel
export conformal_model, fit, predict
export available_models, tested_atomic_models
export set_size

# Conformal Training:
include("training/training.jl")
export soft_assignment

# Evaluation:
Expand Down
11 changes: 9 additions & 2 deletions src/conformal_models/conformal_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,15 @@ include("transductive_regression.jl")
include("inductive_classification.jl")
include("transductive_classification.jl")

# Training:
include("training/training.jl")

# Type unions:
const InductiveModel =
Union{SimpleInductiveRegressor,SimpleInductiveClassifier,AdaptiveInductiveClassifier}
const InductiveModel = Union{
SimpleInductiveRegressor,
SimpleInductiveClassifier,
AdaptiveInductiveClassifier,
}

const TransductiveModel = Union{
NaiveRegressor,
Expand Down Expand Up @@ -95,6 +101,7 @@ const available_models = Dict(
:transductive => Dict(:naive => NaiveClassifier),
:inductive => Dict(
:simple_inductive => SimpleInductiveClassifier,
:trainable_simple_inductive => TrainableSimpleInductiveClassifier,
:adaptive_inductive => AdaptiveInductiveClassifier,
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/conformal_models/inductive_classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function MMI.predict(conf_model::SimpleInductiveClassifier, fitresult, Xnew)
= reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...),
)
v = conf_model.scores
v = conf_model.scores[:calibration]
= StatsBase.quantile(v, conf_model.coverage)
= map(p̂) do pp
L =.decoder.classes
Expand Down
6 changes: 3 additions & 3 deletions src/conformal_models/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function Plots.contourf(
if isnothing(target)
@info "No target label supplied, using first."
end
target = isnothing(target) ? 1 : target
target = isnothing(target) ? levels(y)[1] : target
if plot_set_size
_default_title = "Set size"
elseif plot_set_loss
Expand Down Expand Up @@ -137,7 +137,7 @@ function Plots.contourf(
push!(Z, z)
end
Z = reduce(hcat, Z)
Z = Z[Int(target), :]
Z = Z[findall.(levels(y) .== target)[1][1], :]

# Contour:
if plot_set_size
Expand Down Expand Up @@ -331,4 +331,4 @@ function Plots.bar(
x = sort(levels(idx), lt = natural)
y = [sum(idx .== _x) for _x in x]
Plots.bar(x, y; label = label, xtickfontsize = xtickfontsize, kwrgs...)
end
end
185 changes: 185 additions & 0 deletions src/conformal_models/training/inductive_classification.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
using MLJFlux: MLJFluxModel

"The `TrainableSimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset."
mutable struct TrainableSimpleInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet
model::Model
coverage::AbstractFloat
scores::Union{Nothing,Dict{Any,Any}}
heuristic::Function
train_ratio::AbstractFloat
end

function TrainableSimpleInductiveClassifier(
model::MLJFluxModel;
coverage::AbstractFloat = 0.95,
heuristic::Function = f(p̂) = 1.0 - p̂,
train_ratio::AbstractFloat = 0.5,
)
return TrainableSimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
end

function score(conf_model::TrainableSimpleInductiveClassifier, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
X = size(X,2) == 1 ? X : permutedims(X)
probas = permutedims(fitresult[1](X))
scores = @.(conf_model.heuristic(probas))
if isnothing(y)
return scores
else
cal_scores = getindex.(Ref(scores), 1:size(scores,1), levelcode.(y))
return cal_scores, scores
end
end

@doc raw"""
MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y)
For the [`TrainableSimpleInductiveClassifier`](@ref) nonconformity scores are computed as follows:
``
S_i^{\text{CAL}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}}
``
A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``. The simple approach only takes the softmax probability of the true label into account.
"""
function MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y)

# Data Splitting:
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

# Nonconformity Scores:
cal_scores, scores = score(conf_model, fitresult, matrix(Xcal), ycal)
conf_model.scores = Dict(
:calibration => cal_scores,
:all => scores,
)

return (fitresult, cache, report)
end

@doc raw"""
MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew)
For the [`TrainableSimpleInductiveClassifier`](@ref) prediction sets are computed as follows,
``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, \ i \in \mathcal{D}_{\text{calibration}}
``
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew)
= reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...),
)
v = conf_model.scores[:calibration]
= StatsBase.quantile(v, conf_model.coverage)
= map(p̂) do pp
L =.decoder.classes
probas = pdf.(pp, L)
is_in_set = 1.0 .- probas .<=
if !all(is_in_set .== false)
pp = UnivariateFinite(L[is_in_set], probas[is_in_set])
else
pp = missing
end
return pp
end
return
end

# Adaptive
"The `TrainableAdaptiveInductiveClassifier` is an improvement to the [`TrainableSimpleInductiveClassifier`](@ref) and the [`NaiveClassifier`](@ref). Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset like the [`TrainableSimpleInductiveClassifier`](@ref). Contrary to the [`TrainableSimpleInductiveClassifier`](@ref) it utilizes the softmax output of all classes."
mutable struct TrainableAdaptiveInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet
model::Model
coverage::AbstractFloat
scores::Union{Nothing,AbstractArray}
heuristic::Function
train_ratio::AbstractFloat
end

function TrainableAdaptiveInductiveClassifier(
model::MLJFluxModel;
coverage::AbstractFloat = 0.95,
heuristic::Function = f(y, ŷ) = 1.0 - ŷ,
train_ratio::AbstractFloat = 0.5,
)
return TrainableAdaptiveInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
end

@doc raw"""
MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y)
For the [`TrainableAdaptiveInductiveClassifier`](@ref) nonconformity scores are computed by cumulatively summing the ranked scores of each label in descending order until reaching the true label ``Y_i``:
``
S_i^{\text{CAL}} = s(X_i,Y_i) = \sum_{j=1}^k \hat\mu(X_i)_{\pi_j} \ \text{where } \ Y_i=\pi_k, i \in \mathcal{D}_{\text{calibration}}
``
"""
function MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y)

# Data Splitting:
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

# Nonconformity Scores:
= reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xcal))
L =.decoder.classes
= pdf(p̂, L) # compute probabilities for all classes
scores = map(eachrow(ŷ), eachrow(ycal)) do ŷᵢ, ycalᵢ
ranks = sortperm(.-ŷᵢ) # rank in descending order
index_y = findall(L[ranks] .== ycalᵢ)[1] # index of true y in sorted array
scoreᵢ = last(cumsum(ŷᵢ[ranks][1:index_y])) # sum up until true y is reached
return scoreᵢ
end
conf_model.scores = scores

return (fitresult, cache, report)
end

@doc raw"""
MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew)
For the [`TrainableAdaptiveInductiveClassifier`](@ref) prediction sets are computed as follows,
``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, i \in \mathcal{D}_{\text{calibration}}
``
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew)
= reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...),
)
v = conf_model.scores
= StatsBase.quantile(v, conf_model.coverage)
= map(p̂) do pp
L =.decoder.classes
probas = pdf.(pp, L)
is_in_set = 1.0 .- probas .<=
if !all(is_in_set .== false)
pp = UnivariateFinite(L[is_in_set], probas[is_in_set])
else
pp = missing
end
return pp
end
return
end
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Computes soft assignment scores for each label and sample. That is, the probabil
function soft_assignment(conf_model::ConformalProbabilisticSet; temp::Union{Nothing, Real}=nothing)
temp = isnothing(temp) ? 0.5 : temp
v = sort(conf_model.scores[:calibration])
= Statistics.quantile(v, conf_model.coverage, sorted=true)
= StatsBase.quantile(v, conf_model.coverage, sorted=true)
scores = conf_model.scores[:all]
return @.(σ((q̂ - scores) / temp))
end
Expand All @@ -23,7 +23,7 @@ This function can be used to compute soft assigment probabilities for new data `
function soft_assignment(conf_model::ConformalProbabilisticSet, fitresult, X; temp::Union{Nothing, Real}=nothing)
temp = isnothing(temp) ? 0.5 : temp
v = sort(conf_model.scores[:calibration])
= Statistics.quantile(v, conf_model.coverage, sorted=true)
= StatsBase.quantile(v, conf_model.coverage, sorted=true)
scores = score(conf_model, fitresult, X)
return @.(σ((q̂ - scores) / temp))
end
Expand Down
2 changes: 2 additions & 0 deletions src/conformal_models/training/training.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
include("losses.jl")
include("inductive_classification.jl")
1 change: 0 additions & 1 deletion src/training/training.jl

This file was deleted.

0 comments on commit 56cca1c

Please sign in to comment.