Skip to content

Commit

Permalink
tests passing and streamlined
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Mar 31, 2023
1 parent 56cca1c commit b4c7140
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 187 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
2 changes: 1 addition & 1 deletion src/conformal_models/conformal_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ const available_models = Dict(
:transductive => Dict(:naive => NaiveClassifier),
:inductive => Dict(
:simple_inductive => SimpleInductiveClassifier,
:trainable_simple_inductive => TrainableSimpleInductiveClassifier,
:adaptive_inductive => AdaptiveInductiveClassifier,
),
),
Expand All @@ -120,6 +119,7 @@ const tested_atomic_models = Dict(
:evo_tree => :(@load EvoTreeClassifier pkg = EvoTrees),
:nearest_neighbor => :(@load KNNClassifier pkg = NearestNeighborModels),
:light_gbm => :(@load LGBMClassifier pkg = LightGBM),
:flux => :(@load NeuralNetworkClassifier pkg = MLJFlux),
),
)

Expand Down
17 changes: 9 additions & 8 deletions src/conformal_models/inductive_classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ function SimpleInductiveClassifier(
end

function score(conf_model::SimpleInductiveClassifier, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
# X = isa(X, Matrix) ? table(X) : X
# p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X))
# L = p̂.decoder.classes
# probas = pdf(p̂, L)
X = size(X,2) == 1 ? X : permutedims(X)
probas = permutedims(fitresult[1](X))
score(conf_model, typeof(conf_model.model), fitresult, X, y)
end

function score(conf_model::SimpleInductiveClassifier, ::Type{<:Supervised}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
= reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X))
L =.decoder.classes
probas = pdf(p̂, L)
scores = @.(conf_model.heuristic(probas))
if isnothing(y)
return scores
else
cal_scores = getindex.(Ref(scores), 1:size(scores,1), levelcode.(y))
cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
return cal_scores, scores
end
end
Expand Down Expand Up @@ -59,7 +60,7 @@ function MMI.fit(conf_model::SimpleInductiveClassifier, verbosity, X, y)
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

# Nonconformity Scores:
cal_scores, scores = score(conf_model, fitresult, matrix(Xcal), ycal)
cal_scores, scores = score(conf_model, fitresult, Xcal, ycal)
conf_model.scores = Dict(
:calibration => cal_scores,
:all => scores,
Expand Down
178 changes: 3 additions & 175 deletions src/conformal_models/training/inductive_classification.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,7 @@
using MLJFlux: MLJFluxModel

"The `TrainableSimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset."
mutable struct TrainableSimpleInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet
model::Model
coverage::AbstractFloat
scores::Union{Nothing,Dict{Any,Any}}
heuristic::Function
train_ratio::AbstractFloat
end

function TrainableSimpleInductiveClassifier(
model::MLJFluxModel;
coverage::AbstractFloat = 0.95,
heuristic::Function = f(p̂) = 1.0 - p̂,
train_ratio::AbstractFloat = 0.5,
)
return TrainableSimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
end

function score(conf_model::TrainableSimpleInductiveClassifier, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
X = size(X,2) == 1 ? X : permutedims(X)
function score(conf_model::SimpleInductiveClassifier, ::Type{<:MLJFluxModel}, fitresult, X, y::Union{Nothing,AbstractArray}=nothing)
X = permutedims(matrix(X))
probas = permutedims(fitresult[1](X))
scores = @.(conf_model.heuristic(probas))
if isnothing(y)
Expand All @@ -28,158 +10,4 @@ function score(conf_model::TrainableSimpleInductiveClassifier, fitresult, X, y::
cal_scores = getindex.(Ref(scores), 1:size(scores,1), levelcode.(y))
return cal_scores, scores
end
end

@doc raw"""
MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y)
For the [`TrainableSimpleInductiveClassifier`](@ref) nonconformity scores are computed as follows:
``
S_i^{\text{CAL}} = s(X_i, Y_i) = h(\hat\mu(X_i), Y_i), \ i \in \mathcal{D}_{\text{calibration}}
``
A typical choice for the heuristic function is ``h(\hat\mu(X_i), Y_i)=1-\hat\mu(X_i)_{Y_i}`` where ``\hat\mu(X_i)_{Y_i}`` denotes the softmax output of the true class and ``\hat\mu`` denotes the model fitted on training data ``\mathcal{D}_{\text{train}}``. The simple approach only takes the softmax probability of the true label into account.
"""
function MMI.fit(conf_model::TrainableSimpleInductiveClassifier, verbosity, X, y)

# Data Splitting:
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

# Nonconformity Scores:
cal_scores, scores = score(conf_model, fitresult, matrix(Xcal), ycal)
conf_model.scores = Dict(
:calibration => cal_scores,
:all => scores,
)

return (fitresult, cache, report)
end

@doc raw"""
MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew)
For the [`TrainableSimpleInductiveClassifier`](@ref) prediction sets are computed as follows,
``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, \ i \in \mathcal{D}_{\text{calibration}}
``
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::TrainableSimpleInductiveClassifier, fitresult, Xnew)
= reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...),
)
v = conf_model.scores[:calibration]
= StatsBase.quantile(v, conf_model.coverage)
= map(p̂) do pp
L =.decoder.classes
probas = pdf.(pp, L)
is_in_set = 1.0 .- probas .<=
if !all(is_in_set .== false)
pp = UnivariateFinite(L[is_in_set], probas[is_in_set])
else
pp = missing
end
return pp
end
return
end

# Adaptive
"The `TrainableAdaptiveInductiveClassifier` is an improvement to the [`TrainableSimpleInductiveClassifier`](@ref) and the [`NaiveClassifier`](@ref). Contrary to the [`NaiveClassifier`](@ref) it computes nonconformity scores using a designated calibration dataset like the [`TrainableSimpleInductiveClassifier`](@ref). Contrary to the [`TrainableSimpleInductiveClassifier`](@ref) it utilizes the softmax output of all classes."
mutable struct TrainableAdaptiveInductiveClassifier{Model<:MLJFluxModel} <: ConformalProbabilisticSet
model::Model
coverage::AbstractFloat
scores::Union{Nothing,AbstractArray}
heuristic::Function
train_ratio::AbstractFloat
end

function TrainableAdaptiveInductiveClassifier(
model::MLJFluxModel;
coverage::AbstractFloat = 0.95,
heuristic::Function = f(y, ŷ) = 1.0 - ŷ,
train_ratio::AbstractFloat = 0.5,
)
return TrainableAdaptiveInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
end

@doc raw"""
MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y)
For the [`TrainableAdaptiveInductiveClassifier`](@ref) nonconformity scores are computed by cumulatively summing the ranked scores of each label in descending order until reaching the true label ``Y_i``:
``
S_i^{\text{CAL}} = s(X_i,Y_i) = \sum_{j=1}^k \hat\mu(X_i)_{\pi_j} \ \text{where } \ Y_i=\pi_k, i \in \mathcal{D}_{\text{calibration}}
``
"""
function MMI.fit(conf_model::TrainableAdaptiveInductiveClassifier, verbosity, X, y)

# Data Splitting:
train, calibration = partition(eachindex(y), conf_model.train_ratio)
Xtrain = selectrows(X, train)
ytrain = y[train]
Xtrain, ytrain = MMI.reformat(conf_model.model, Xtrain, ytrain)
Xcal = selectrows(X, calibration)
ycal = y[calibration]
Xcal, ycal = MMI.reformat(conf_model.model, Xcal, ycal)

# Training:
fitresult, cache, report = MMI.fit(conf_model.model, verbosity, Xtrain, ytrain)

# Nonconformity Scores:
= reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, Xcal))
L =.decoder.classes
= pdf(p̂, L) # compute probabilities for all classes
scores = map(eachrow(ŷ), eachrow(ycal)) do ŷᵢ, ycalᵢ
ranks = sortperm(.-ŷᵢ) # rank in descending order
index_y = findall(L[ranks] .== ycalᵢ)[1] # index of true y in sorted array
scoreᵢ = last(cumsum(ŷᵢ[ranks][1:index_y])) # sum up until true y is reached
return scoreᵢ
end
conf_model.scores = scores

return (fitresult, cache, report)
end

@doc raw"""
MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew)
For the [`TrainableAdaptiveInductiveClassifier`](@ref) prediction sets are computed as follows,
``
\hat{C}_{n,\alpha}(X_{n+1}) = \left\{y: s(X_{n+1},y) \le \hat{q}_{n, \alpha}^{+} \{S_i^{\text{CAL}}\} \right\}, i \in \mathcal{D}_{\text{calibration}}
``
where ``\mathcal{D}_{\text{calibration}}`` denotes the designated calibration data.
"""
function MMI.predict(conf_model::TrainableAdaptiveInductiveClassifier, fitresult, Xnew)
= reformat_mlj_prediction(
MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xnew)...),
)
v = conf_model.scores
= StatsBase.quantile(v, conf_model.coverage)
= map(p̂) do pp
L =.decoder.classes
probas = pdf.(pp, L)
is_in_set = 1.0 .- probas .<=
if !all(is_in_set .== false)
pp = UnivariateFinite(L[is_in_set], probas[is_in_set])
else
pp = missing
end
return pp
end
return
end
end
Loading

0 comments on commit b4c7140

Please sign in to comment.