-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconformal_models.jl
127 lines (109 loc) · 4.91 KB
/
conformal_models.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
using MLJBase
import MLJModelInterface as MMI
import MLJModelInterface: predict, fit, save, restore
"""
An abstract base type for conformal models that produce interval-valued predictions.
This includes most conformal regression models.
"""
abstract type ConformalInterval <: MMI.Interval end

"""
An abstract base type for conformal models that produce set-valued probabilistic
predictions. This includes most conformal classification models.
"""
abstract type ConformalProbabilisticSet <: MMI.ProbabilisticSet end

"""
An abstract base type for conformal models that produce probabilistic predictions.
This includes some conformal classifiers, such as Venn-ABERS.
"""
abstract type ConformalProbabilistic <: MMI.Probabilistic end

# Any conformal model, whichever of the three prediction kinds above it produces.
const ConformalModel = Union{
    ConformalInterval,
    ConformalProbabilisticSet,
    ConformalProbabilistic,
}
include("utils.jl")
include("heuristics.jl")
# Main API call to wrap model:
"""
    conformal_model(model::Supervised; method::Union{Nothing,Symbol}=nothing, kwargs...)

Wrap `model::Supervised` into a conformal model.

`model` is treated as a classifier when its target scitype is a subtype of
`AbstractVector{<:Finite}`, and as a regressor otherwise.

# Keywords
- `method::Union{Nothing,Symbol}=nothing`: the conformal prediction method to use. It must
  be one of the keys listed (across both `:transductive` and `:inductive`) under
  `available_models[:classification]` for classifiers or `available_models[:regression]`
  for regressors. When `nothing`, `SimpleInductiveClassifier` is used for classifiers and
  `SimpleInductiveRegressor` for regressors.
- `kwargs...`: additional keyword arguments specific to the chosen `method`; they are
  forwarded unchanged to that method's constructor.

# Returns
The conformal model `M(model; kwargs...)`, where `M` is the model type selected by `method`.

# Throws
- `AssertionError` if `method` is not a valid method for the kind of `model`.
"""
function conformal_model(
    model::Supervised; method::Union{Nothing,Symbol}=nothing, kwargs...
)
    is_classifier = target_scitype(model) <: AbstractVector{<:Finite}
    conf_type = if isnothing(method)
        is_classifier ? SimpleInductiveClassifier : SimpleInductiveRegressor
    else
        _lookup_conformal_method(method, is_classifier)
    end
    return conf_type(model; kwargs...)
end

"""
    _lookup_conformal_method(method::Symbol, is_classifier::Bool)

Return the conformal model type registered under `method` in `available_models` for
classifiers (`is_classifier == true`) or regressors (`is_classifier == false`).
Throw an `AssertionError` listing the valid names if `method` is not registered.
"""
function _lookup_conformal_method(method::Symbol, is_classifier::Bool)
    task, label = if is_classifier
        (:classification, "classifiers")
    else
        (:regression, "regressors")
    end
    # Flatten the transductive and inductive tables into a single name => type lookup.
    registry = merge(values(available_models[task])...)
    if !haskey(registry, method)
        valid = join(sort!(collect(keys(registry))), ", ")
        # Thrown explicitly instead of via `@assert`, which may be disabled. The
        # `AssertionError` type is kept so existing callers/tests expecting it still work.
        throw(
            AssertionError(
                "$(method) is not a valid method for $(label). " *
                "Available methods: $(valid).",
            ),
        )
    end
    return registry[method]
end
# Regression Models:
include("inductive_regression.jl")
include("transductive_regression.jl")
# Classification Models
include("inductive_classification.jl")
include("transductive_classification.jl")
# Training:
include("ConformalTraining/ConformalTraining.jl")
using .ConformalTraining
# Type unions:
# All inductive conformal model types; mirrors the `:inductive` entries of
# `available_models` (regression and classification).
const InductiveModel = Union{
    SimpleInductiveRegressor,
    SimpleInductiveClassifier,
    AdaptiveInductiveClassifier,
}
# All transductive conformal model types; mirrors the `:transductive` entries of
# `available_models` (regression and classification).
const TransductiveModel = Union{
    # regression
    NaiveRegressor,
    JackknifeRegressor,
    JackknifePlusRegressor,
    JackknifeMinMaxRegressor,
    CVPlusRegressor,
    CVMinMaxRegressor,
    JackknifePlusAbRegressor,
    JackknifePlusAbMinMaxRegressor,
    TimeSeriesRegressorEnsembleBatch,
    # classification
    NaiveClassifier,
}
"""
A container listing all available methods for conformal prediction.

It is keyed first by task (`:regression`, `:classification`) and then by approach
(`:transductive`, `:inductive`). The innermost dictionaries map each method's `Symbol` name
to the corresponding conformal model type.
"""
const available_models = let
    regression_methods = Dict(
        :transductive => Dict(
            :naive => NaiveRegressor,
            :jackknife => JackknifeRegressor,
            :jackknife_plus => JackknifePlusRegressor,
            :jackknife_minmax => JackknifeMinMaxRegressor,
            :cv_plus => CVPlusRegressor,
            :cv_minmax => CVMinMaxRegressor,
            :jackknife_plus_ab => JackknifePlusAbRegressor,
            :jackknife_plus_ab_minmax => JackknifePlusAbMinMaxRegressor,
            :time_series_ensemble_batch => TimeSeriesRegressorEnsembleBatch,
        ),
        :inductive => Dict(:simple_inductive => SimpleInductiveRegressor),
    )
    classification_methods = Dict(
        :transductive => Dict(:naive => NaiveClassifier),
        :inductive => Dict(
            :simple_inductive => SimpleInductiveClassifier,
            :adaptive_inductive => AdaptiveInductiveClassifier,
        ),
    )
    Dict(:regression => regression_methods, :classification => classification_methods)
end
"""
A container listing all atomic MLJ models that have been tested for use with this package.

It is keyed by task (`:regression`, `:classification`). Each entry maps a short name to a
*quoted* `@load` expression. The expressions are not evaluated here, so no model package is
loaded until a caller evaluates one of them.
"""
const tested_atomic_models = let
    regressors = Dict(
        :linear => :(@load LinearRegressor pkg = MLJLinearModels),
        :ridge => :(@load RidgeRegressor pkg = MLJLinearModels),
        :lasso => :(@load LassoRegressor pkg = MLJLinearModels),
        :evo_tree => :(@load EvoTreeRegressor pkg = EvoTrees),
        :nearest_neighbor => :(@load KNNRegressor pkg = NearestNeighborModels),
        # Currently disabled:
        # :light_gbm => :(@load LGBMRegressor pkg = LightGBM),
        # :neural_network => :(@load NeuralNetworkRegressor pkg = MLJFlux),
        # :symbolic_regression => :(@load SRRegressor pkg = SymbolicRegression),
    )
    classifiers = Dict(
        :logistic => :(@load LogisticClassifier pkg = MLJLinearModels),
        :evo_tree => :(@load EvoTreeClassifier pkg = EvoTrees),
        :nearest_neighbor => :(@load KNNClassifier pkg = NearestNeighborModels),
        # Currently disabled:
        # :light_gbm => :(@load LGBMClassifier pkg = LightGBM),
        # :neural_network => :(@load NeuralNetworkClassifier pkg = MLJFlux),
    )
    Dict(:regression => regressors, :classification => classifiers)
end
include("model_traits.jl")