diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 223ccc0..e9b4a8d 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -19,6 +19,7 @@ jobs:
fail-fast: false
matrix:
version:
+ - '1.6'
- '1.7'
- '1.8'
- '1.9'
diff --git a/Project.toml b/Project.toml
index 040764f..c5434b2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "ConformalPrediction"
uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
authors = ["Patrick Altmeyer"]
-version = "0.1.8"
+version = "0.1.9"
[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -26,6 +26,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
CategoricalArrays = "0.10"
ChainRules = "1.49.0"
+ComputationalResources = "0.3"
Flux = "0.13.16, 0.14"
MLJBase = "0.20, 0.21"
MLJEnsembles = "0.3.3"
@@ -34,7 +35,9 @@ MLJModelInterface = "1"
MLUtils = "0.4.2"
NaturalSort = "1"
Plots = "1"
+ProgressMeter = "1"
StatsBase = "0.33, 0.34.0"
+Tables = "1"
julia = "1.7, 1.8, 1.9"
[extras]
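
The hunks above cut release 0.1.9 and add `[compat]` entries for `ComputationalResources`, `ProgressMeter` and `Tables`. A small sketch of picking up the pinned release with the standard Pkg API; only the version number comes from the hunk, the rest is generic:

```julia
using Pkg
Pkg.add(name="ConformalPrediction", version="0.1.9")   # pin the release cut in this diff
Pkg.status("ConformalPrediction")                       # confirm the resolved version
```
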
diff --git a/README.md b/README.md
index c9709e2..0f75edc 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
![](dev/logo/wide_logo.png)
-[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/) [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor’s Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet.png)](https://github.com/SciML/ColPrac) [![Twitter Badge](https://img.shields.io/twitter/url/https/twitter.com/paltmey.svg?style=social&label=Follow%20%40paltmey)](https://twitter.com/paltmey)
+[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/) [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![License](https://img.shields.io/github/license/juliatrustworthyai/ConformalPrediction.jl)](LICENSE) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/ConformalPrediction/.png)](https://pkgs.genieframework.com?packages=ConformalPrediction)
`ConformalPrediction.jl` is a package for Predictive Uncertainty Quantification (UQ) through Conformal Prediction (CP) in Julia. It is designed to work with supervised models trained in [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) (Blaom et al. 2020). Conformal Prediction is easy to understand, easy to use, and model-agnostic, and it works under minimal distributional assumptions.
@@ -71,7 +71,7 @@ X = reshape(X, :, 1)
# Outputs:
noise = 0.5
-fun(X) = X * sin(X)
+fun(X) = sin(X)
ε = randn(N) .* noise
y = @.(fun(X)) + ε
y = vec(y)
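
For context, the hunk above changes the ground-truth function of the README's toy regression example from `X * sin(X)` to `sin(X)`. A self-contained sketch of the surrounding data-generating code; `N` and the input range are placeholders, not values taken from the README:

```julia
N = 1000                          # placeholder sample size
X = 6 .* rand(N) .- 3             # placeholder inputs on [-3, 3]
noise = 0.5
fun(x) = sin(x)                   # new ground-truth function (was x * sin(x))
ε = randn(N) .* noise
y = fun.(X) .+ ε                  # noisy targets
X = reshape(X, :, 1)              # one-column feature matrix, as in the README
```
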
@@ -111,11 +111,11 @@ ŷ[1:show_first]
```
5-element Vector{Tuple{Float64, Float64}}:
- (-0.40997718991694765, 1.449009293726001)
- (0.8484810430118421, 2.7074675266547907)
- (0.547852151594671, 2.4068386352376194)
- (-0.022697652913589494, 1.8362888307293592)
- (0.07435130847990101, 1.9333377921228496)
+ (0.0458889297242715, 1.9182762960257687)
+ (-1.9174452847238976, -0.04505791842240037)
+ (-1.2544275358451678, 0.6179598304563294)
+ (-0.2818835218505735, 1.5905038444509236)
+ (0.01299565032151917, 1.8853830166230163)
For simple models like this one, we can call a custom `Plots` recipe on our instance, fit result and data to generate the chart below:
@@ -138,8 +138,6 @@ println("Empirical coverage: $(round(_eval.measurement[1], digits=3))")
println("SSC: $(round(_eval.measurement[2], digits=3))")
```
- Started!
-
PerformanceEvaluation object with these fields:
measure, operation, measurement, per_fold,
per_observation, fitted_params_per_fold,
@@ -148,11 +146,14 @@ println("SSC: $(round(_eval.measurement[2], digits=3))")
┌──────────────────────────────────────────────┬───────────┬─────────────┬──────
│ measure │ operation │ measurement │ 1.9 ⋯
├──────────────────────────────────────────────┼───────────┼─────────────┼──────
- │ ConformalPrediction.emp_coverage │ predict │ 0.945 │ 0.0 ⋯
- │ ConformalPrediction.size_stratified_coverage │ predict │ 0.945 │ 0.0 ⋯
+ │ ConformalPrediction.emp_coverage │ predict │ 0.948 │ 0.0 ⋯
+ │ ConformalPrediction.size_stratified_coverage │ predict │ 0.948 │ 0.0 ⋯
└──────────────────────────────────────────────┴───────────┴─────────────┴──────
2 columns omitted
+ Empirical coverage: 0.948
+ SSC: 0.948
+
## 📚 Read on
If after reading the usage example above you are just left with more questions about the topic, that’s normal. Below we have collected a number of further resources to help you get started with this package and the topic itself:
@@ -231,7 +232,7 @@ There is also a simple `Plots.jl` recipe that can be used to inspect the set siz
bar(mach.model, mach.fitresult, X)
```
-![](README_files/figure-commonmark/cell-11-output-1.svg)
+![](README_files/figure-commonmark/cell-12-output-1.svg)
## 🛠 Contribute
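
The hunks above reference two `Plots` recipes: `plot` for regression intervals and `bar` for classification set sizes. A hedged sketch of how they are invoked once a conformal machine is fitted; the data and the stand-in model are illustrative, and keyword options are omitted:

```julia
using MLJ, ConformalPrediction, Plots

N = 200
X = reshape(6 .* rand(N) .- 3, :, 1)            # toy inputs (placeholder range)
y = vec(sin.(X)) .+ 0.5 .* randn(N)             # toy targets

model = DeterministicConstantRegressor()        # stand-in for any MLJ regressor
conf_model = conformal_model(model)             # main API call to wrap the model
mach = machine(conf_model, X, y)
fit!(mach)

plot(mach.model, mach.fitresult, X, y)          # recipe on instance, fit result and data

# For conformal classifiers, set sizes can be inspected analogously:
# bar(mach.model, mach.fitresult, X)
```
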
diff --git a/README.qmd b/README.qmd
index feac2e7..2f3e6b8 100644
--- a/README.qmd
+++ b/README.qmd
@@ -22,7 +22,7 @@ jupyter: julia-1.9
[![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
-[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
-[![Twitter Badge](https://img.shields.io/twitter/url/https/twitter.com/paltmey.svg?style=social&label=Follow%20%40paltmey)](https://twitter.com/paltmey)
+ [![License](https://img.shields.io/github/license/juliatrustworthyai/ConformalPrediction.jl)](LICENSE)
+ [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/ConformalPrediction/)](https://pkgs.genieframework.com?packages=ConformalPrediction)
{{< include docs/src/_intro.qmd >}}
\ No newline at end of file
diff --git a/README_files/figure-commonmark/cell-12-output-1.svg b/README_files/figure-commonmark/cell-12-output-1.svg
new file mode 100644
index 0000000..0ea75b9
--- /dev/null
+++ b/README_files/figure-commonmark/cell-12-output-1.svg
@@ -0,0 +1,49 @@
+[49 lines of SVG markup omitted]
diff --git a/README_files/figure-commonmark/cell-7-output-1.svg b/README_files/figure-commonmark/cell-7-output-1.svg
index 1db8c0b..b253af4 100644
--- a/README_files/figure-commonmark/cell-7-output-1.svg
+++ b/README_files/figure-commonmark/cell-7-output-1.svg
@@ -1,294 +1,288 @@
+[288 lines of SVG markup omitted]
diff --git a/src/conformal_models/ConformalTraining/inductive_classification.jl b/src/conformal_models/ConformalTraining/inductive_classification.jl
index 8a76eee..4f4e72d 100644
--- a/src/conformal_models/ConformalTraining/inductive_classification.jl
+++ b/src/conformal_models/ConformalTraining/inductive_classification.jl
@@ -18,7 +18,7 @@ function ConformalPrediction.score(
)
X = permutedims(matrix(X))
probas = permutedims(fitresult[1](X))
- scores = @.(conf_model.heuristic(probas))
+ scores = @.(conf_model.heuristic(y, probas))
if isnothing(y)
return scores
else
@@ -46,7 +46,7 @@ function ConformalPrediction.score(
p ->
mean(p; dims=ndims(p)) |>
p -> MLUtils.unstack(p; dims=ndims(p))[1] |> p -> permutedims(p)
- scores = @.(conf_model.heuristic(probas))
+ scores = @.(conf_model.heuristic(y, probas))
if isnothing(y)
return scores
else
diff --git a/src/conformal_models/conformal_models.jl b/src/conformal_models/conformal_models.jl
index 7c1068d..7899d07 100644
--- a/src/conformal_models/conformal_models.jl
+++ b/src/conformal_models/conformal_models.jl
@@ -16,6 +16,7 @@ const ConformalModel = Union{
}
include("utils.jl")
+include("heuristics.jl")
include("plotting.jl")
# Main API call to wrap model:
diff --git a/src/conformal_models/heuristics.jl b/src/conformal_models/heuristics.jl
new file mode 100644
index 0000000..035f3d2
--- /dev/null
+++ b/src/conformal_models/heuristics.jl
@@ -0,0 +1,13 @@
+"""
+ minus_softmax(y,ŷ)
+
+Computes `1.0 - ŷ` where `ŷ` is the softmax output for a given class.
+"""
+minus_softmax(y, ŷ) = 1.0 - ŷ
+
+"""
+ absolute_error(y,ŷ)
+
+Computes `abs(y - ŷ)` where `ŷ` is the predicted value.
+"""
+absolute_error(y, ŷ) = abs(y - ŷ)
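
These two helpers replace the anonymous `f(y, ŷ) = ...` closures previously inlined in every constructor (see the changes below). A quick standalone sketch of what they compute, with made-up values:

```julia
minus_softmax(y, ŷ) = 1.0 - ŷ        # classification: one minus the softmax output
absolute_error(y, ŷ) = abs(y - ŷ)    # regression: absolute residual

minus_softmax(1, 0.9)      # ≈ 0.1  (confident prediction ⇒ low nonconformity)
absolute_error(2.0, 1.4)   # ≈ 0.6
```

Both take `y` even when it is unused, so that every heuristic shares the same `(y, ŷ)` signature; this is why the `score` methods above now call `conf_model.heuristic(y, probas)`.
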
diff --git a/src/conformal_models/inductive_classification.jl b/src/conformal_models/inductive_classification.jl
index 02f3752..9d6b82b 100644
--- a/src/conformal_models/inductive_classification.jl
+++ b/src/conformal_models/inductive_classification.jl
@@ -41,7 +41,7 @@ end
function SimpleInductiveClassifier(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(p̂) = 1.0 - p̂,
+ heuristic::Function=minus_softmax,
train_ratio::AbstractFloat=0.5,
)
return SimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
@@ -62,7 +62,7 @@ function score(
p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X))
L = p̂.decoder.classes
probas = pdf(p̂, L)
- scores = @.(conf_model.heuristic(probas))
+ scores = @.(conf_model.heuristic(y, probas))
if isnothing(y)
return scores
else
@@ -141,7 +141,7 @@ end
function AdaptiveInductiveClassifier(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = 1.0 - ŷ,
+ heuristic::Function=minus_softmax,
train_ratio::AbstractFloat=0.5,
)
return AdaptiveInductiveClassifier(model, coverage, nothing, heuristic, train_ratio)
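
Since the defaults are now ordinary named functions, a custom nonconformity heuristic can still be passed through the keyword shown in the constructors above. A minimal sketch; the classifier and the alternative heuristic are made up for illustration, and the constructor is fully qualified in case it is not exported:

```julia
using MLJ, ConformalPrediction

clf = ConstantClassifier()          # stand-in for any probabilistic MLJ classifier
my_heuristic(y, ŷ) = 1.0 - ŷ^2      # hypothetical alternative score; must accept (y, ŷ)
conf_model = ConformalPrediction.SimpleInductiveClassifier(
    clf; coverage=0.9, heuristic=my_heuristic
)
```
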
diff --git a/src/conformal_models/inductive_regression.jl b/src/conformal_models/inductive_regression.jl
index c2ce341..7a72037 100644
--- a/src/conformal_models/inductive_regression.jl
+++ b/src/conformal_models/inductive_regression.jl
@@ -10,7 +10,7 @@ end
function SimpleInductiveRegressor(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
train_ratio::AbstractFloat=0.5,
)
return SimpleInductiveRegressor(model, coverage, nothing, heuristic, train_ratio)
diff --git a/src/conformal_models/transductive_classification.jl b/src/conformal_models/transductive_classification.jl
index f188aca..350bef7 100644
--- a/src/conformal_models/transductive_classification.jl
+++ b/src/conformal_models/transductive_classification.jl
@@ -8,7 +8,7 @@ mutable struct NaiveClassifier{Model<:Supervised} <: ConformalProbabilisticSet
end
function NaiveClassifier(
- model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=f(y, ŷ) = 1.0 - ŷ
+ model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=minus_softmax
)
return NaiveClassifier(model, coverage, nothing, heuristic)
end
diff --git a/src/conformal_models/transductive_regression.jl b/src/conformal_models/transductive_regression.jl
index 9ee2313..864604e 100644
--- a/src/conformal_models/transductive_regression.jl
+++ b/src/conformal_models/transductive_regression.jl
@@ -13,9 +13,7 @@ mutable struct NaiveRegressor{Model<:Supervised} <: ConformalInterval
end
function NaiveRegressor(
- model::Supervised;
- coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error
)
return NaiveRegressor(model, coverage, nothing, heuristic)
end
@@ -81,9 +79,7 @@ mutable struct JackknifeRegressor{Model<:Supervised} <: ConformalInterval
end
function JackknifeRegressor(
- model::Supervised;
- coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error
)
return JackknifeRegressor(model, coverage, nothing, heuristic)
end
@@ -163,9 +159,7 @@ mutable struct JackknifePlusRegressor{Model<:Supervised} <: ConformalInterval
end
function JackknifePlusRegressor(
- model::Supervised;
- coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error
)
return JackknifePlusRegressor(model, coverage, nothing, heuristic)
end
@@ -254,9 +248,7 @@ mutable struct JackknifeMinMaxRegressor{Model<:Supervised} <: ConformalInterval
end
function JackknifeMinMaxRegressor(
- model::Supervised;
- coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error
)
return JackknifeMinMaxRegressor(model, coverage, nothing, heuristic)
end
@@ -347,7 +339,7 @@ end
function CVPlusRegressor(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
cv::MLJBase.CV=MLJBase.CV(),
)
return CVPlusRegressor(model, coverage, nothing, heuristic, cv)
@@ -452,7 +444,7 @@ end
function CVMinMaxRegressor(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
cv::MLJBase.CV=MLJBase.CV(),
)
return CVMinMaxRegressor(model, coverage, nothing, heuristic, cv)
@@ -580,7 +572,7 @@ end
function JackknifePlusAbRegressor(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
nsampling::Int=30,
sample_size::AbstractFloat=0.5,
replacement::Bool=true,
@@ -686,7 +678,7 @@ end
function JackknifePlusAbMinMaxRegressor(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
nsampling::Int=30,
sample_size::AbstractFloat=0.5,
replacement::Bool=true,
@@ -789,7 +781,7 @@ end
function TimeSeriesRegressorEnsembleBatch(
model::Supervised;
coverage::AbstractFloat=0.95,
- heuristic::Function=f(y, ŷ) = abs(y - ŷ),
+ heuristic::Function=absolute_error,
nsampling::Int=50,
sample_size::AbstractFloat=0.3,
aggregate::Union{Symbol,String}="mean",
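
The regression constructors touched above now share `absolute_error` as their default heuristic. An end-to-end sketch under assumed data and model choices; nothing below is prescribed by the diff, and the constructor is fully qualified in case it is not exported:

```julia
using MLJ, ConformalPrediction

X, y = make_regression(100, 2)                   # toy regression data
model = DeterministicConstantRegressor()         # stand-in for any MLJ regressor
conf_model = ConformalPrediction.JackknifePlusRegressor(model)  # default heuristic: absolute_error
mach = machine(conf_model, X, y)
fit!(mach)
predict(mach, X)                                 # prediction intervals
```
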