-
Notifications
You must be signed in to change notification settings - Fork 12
/
inductive_classification.jl
executable file
·123 lines (115 loc) · 4.36 KB
/
inductive_classification.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
using CategoricalArrays
using ConformalPrediction: SimpleInductiveClassifier, AdaptiveInductiveClassifier
using MLJEnsembles: EitherEnsembleModel
using MLJFlux: MLJFluxModel, reformat
using MLUtils
"""
    ConformalPrediction.score(
        conf_model::SimpleInductiveClassifier,
        ::Type{<:MLJFluxModel},
        fitresult,
        X,
        y::Union{Nothing,AbstractArray}=nothing,
    )

Compute `SimpleInductiveClassifier` nonconformity scores for a single `MLJFluxModel`.

`matrix(X)` is transposed and passed to the fitted chain `fitresult[1]`; the chain output
is transposed back so that row `i` of the probability matrix lines up with `y[i]`.
`conf_model.heuristic` is then broadcast over `y` and that probability matrix.

# Returns
- `scores`, the full heuristic matrix, when `y === nothing`;
- otherwise the tuple `(cal_scores, scores)`, where `cal_scores[i]` is the entry of row
  `i` in the column given by `levelcode(y[i])`.
"""
function ConformalPrediction.score(
    conf_model::SimpleInductiveClassifier,
    ::Type{<:MLJFluxModel},
    fitresult,
    X,
    y::Union{Nothing,AbstractArray}=nothing,
)
    chain = fitresult[1]
    probas = permutedims(chain(permutedims(matrix(X))))
    # Broadcasting treats `nothing` as a scalar, so this also works without labels.
    scores = conf_model.heuristic.(y, probas)
    isnothing(y) && return scores
    # Keep only the score each calibration point receives for its observed label.
    observed_cols = levelcode.(y)
    cal_scores = getindex.(Ref(scores), axes(scores, 1), observed_cols)
    return cal_scores, scores
end
"""
    ConformalPrediction.score(
        conf_model::SimpleInductiveClassifier,
        ::Type{<:EitherEnsembleModel{<:MLJFluxModel}},
        fitresult,
        X,
        y::Union{Nothing,AbstractArray}=nothing,
    )

Compute `SimpleInductiveClassifier` nonconformity scores for an ensemble of
`MLJFluxModel`s.

The chain of every ensemble member (`member[1]` for `member` in `fitresult.ensemble`) is
applied to the transposed `matrix(X)`. The member outputs are combined with
`MLUtils.stack`, averaged over the last dimension of the stacked array, and transposed so
that row `i` lines up with `y[i]`. `conf_model.heuristic` is then broadcast over `y` and
the averaged probability matrix.

# Returns
- `scores` when `y === nothing`;
- otherwise `(cal_scores, scores)`, with `cal_scores[i]` taken from row `i`, column
  `levelcode(y[i])`.
"""
function ConformalPrediction.score(
    conf_model::SimpleInductiveClassifier,
    ::Type{<:EitherEnsembleModel{<:MLJFluxModel}},
    fitresult,
    X,
    y::Union{Nothing,AbstractArray}=nothing,
)
    features = permutedims(matrix(X))
    member_outputs = map(member -> member[1](features), fitresult.ensemble)
    stacked = MLUtils.stack(member_outputs)
    # Presumably the stacking dimension (the last one) indexes ensemble members.
    member_dim = ndims(stacked)
    averaged = mean(stacked; dims=member_dim)
    probas = permutedims(first(MLUtils.unstack(averaged; dims=member_dim)))
    scores = conf_model.heuristic.(y, probas)
    isnothing(y) && return scores
    observed_cols = levelcode.(y)
    cal_scores = getindex.(Ref(scores), axes(scores, 1), observed_cols)
    return cal_scores, scores
end
# Build the array fed to the Flux chain(s): MLJFlux `reformat` first, falling back to a
# transposed `matrix` when `reformat` does not return an array.
function _adaptive_flux_input(X)
    Xr = reformat(X)
    return Xr isa AbstractArray ? Xr : permutedims(matrix(Xr))
end

# Cumulative-probability scores from a probability matrix whose rows are observations and
# whose columns are classes: `scores[i, k]` is the sum of row `i`'s probabilities over
# every class ranked at or before class `k` when the row is sorted in descending order
# (`sortperm` is stable, so ties keep column order). Each row is sorted once and the
# running sum is scattered back to the original column positions.
function _adaptive_cumulative_scores(probas::AbstractMatrix, n_levels::Integer)
    n_classes = size(probas, 2)
    n_classes == n_levels || throw(
        DimensionMismatch(
            "model returned $n_classes class probabilities but has $n_levels levels",
        ),
    )
    scores = Matrix{float(eltype(probas))}(undef, size(probas))
    for i in axes(probas, 1)
        pᵢ = view(probas, i, :)
        ranks = sortperm(.-pᵢ)                  # column indices, most probable first
        scores[i, ranks] .= cumsum(pᵢ[ranks])   # mass up to and including each class
    end
    return scores
end

"""
    ConformalPrediction.score(
        conf_model::AdaptiveInductiveClassifier,
        ::Type{<:MLJFluxModel},
        fitresult,
        X,
        y::Union{Nothing,AbstractArray}=nothing,
    )

Compute `AdaptiveInductiveClassifier` nonconformity scores for a single `MLJFluxModel`.

Class probabilities come from the fitted chain `fitresult[1]`, transposed so that rows line
up with observations. Entry `(i, k)` of the result is the total probability of all classes
ranked at or above class `k` in row `i`, ranking by descending probability.

# Returns
- `scores` when `y === nothing`;
- otherwise `(cal_scores, scores)`, with `cal_scores[i] = scores[i, levelcode(y[i])]`.

# Throws
- `DimensionMismatch` if the number of probability columns differs from
  `length(levels(fitresult[2]))`.
"""
function ConformalPrediction.score(
    conf_model::AdaptiveInductiveClassifier,
    ::Type{<:MLJFluxModel},
    fitresult,
    X,
    y::Union{Nothing,AbstractArray}=nothing,
)
    L = levels(fitresult[2])
    probas = permutedims(fitresult[1](_adaptive_flux_input(X)))
    scores = _adaptive_cumulative_scores(probas, length(L))
    isnothing(y) && return scores
    cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
    return cal_scores, scores
end

"""
    ConformalPrediction.score(
        conf_model::AdaptiveInductiveClassifier,
        ::Type{<:EitherEnsembleModel{<:MLJFluxModel}},
        fitresult,
        X,
        y::Union{Nothing,AbstractArray}=nothing,
    )

Compute `AdaptiveInductiveClassifier` nonconformity scores for an ensemble of
`MLJFluxModel`s.

Every member chain (`member[1]` for `member` in `fitresult.ensemble`) is applied to the
same input. Outputs are combined with `MLUtils.stack` and averaged over the last
dimension of the stacked array. The result is transposed so that rows line up with
observations, then scored as in the single-model method.

# Returns
- `scores` when `y === nothing`;
- otherwise `(cal_scores, scores)`, with `cal_scores[i] = scores[i, levelcode(y[i])]`.

# Throws
- `DimensionMismatch` if the number of probability columns differs from
  `length(levels(fitresult.ensemble[1][2]))`.
"""
function ConformalPrediction.score(
    conf_model::AdaptiveInductiveClassifier,
    ::Type{<:EitherEnsembleModel{<:MLJFluxModel}},
    fitresult,
    X,
    y::Union{Nothing,AbstractArray}=nothing,
)
    L = levels(fitresult.ensemble[1][2])
    input = _adaptive_flux_input(X)
    stacked = MLUtils.stack(map(member -> member[1](input), fitresult.ensemble))
    avg_dim = ndims(stacked)
    # Average member outputs over the last (stacking) dimension, then drop it.
    probas = permutedims(dropdims(mean(stacked; dims=avg_dim); dims=avg_dim))
    scores = _adaptive_cumulative_scores(probas, length(L))
    isnothing(y) && return scores
    cal_scores = getindex.(Ref(scores), 1:size(scores, 1), levelcode.(y))
    return cal_scores, scores
end