diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 223ccc0..e9b4a8d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,6 +19,7 @@ jobs: fail-fast: false matrix: version: + - '1.6' - '1.7' - '1.8' - '1.9' diff --git a/Project.toml b/Project.toml index 040764f..c5434b2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ConformalPrediction" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" authors = ["Patrick Altmeyer"] -version = "0.1.8" +version = "0.1.9" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -26,6 +26,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] CategoricalArrays = "0.10" ChainRules = "1.49.0" +ComputationalResources = "0.3" Flux = "0.13.16, 0.14" MLJBase = "0.20, 0.21" MLJEnsembles = "0.3.3" @@ -34,7 +35,9 @@ MLJModelInterface = "1" MLUtils = "0.4.2" NaturalSort = "1" Plots = "1" +ProgressMeter = "1" StatsBase = "0.33, 0.34.0" +Tables = "1" julia = "1.7, 1.8, 1.9" [extras] diff --git a/README.md b/README.md index c9709e2..0f75edc 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ![](dev/logo/wide_logo.png) -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/) [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor’s Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet.png)](https://github.com/SciML/ColPrac) [![Twitter Badge](https://img.shields.io/twitter/url/https/twitter.com/paltmey.svg?style=social&label=Follow%20%40paltmey)](https://twitter.com/paltmey) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/) [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![License](https://img.shields.io/github/license/juliatrustworthyai/ConformalPrediction.jl)](LICENSE) [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/ConformalPrediction/.png)](https://pkgs.genieframework.com?packages=ConformalPrediction) `ConformalPrediction.jl` is a package for Predictive Uncertainty Quantification (UQ) through Conformal Prediction (CP) in Julia. It is designed to work with supervised models trained in [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) (Blaom et al. 2020). Conformal Prediction is easy-to-understand, easy-to-use and model-agnostic and it works under minimal distributional assumptions. @@ -71,7 +71,7 @@ X = reshape(X, :, 1) # Outputs: noise = 0.5 -fun(X) = X * sin(X) +fun(X) = sin(X) ε = randn(N) .* noise y = @.(fun(X)) + ε y = vec(y) @@ -111,11 +111,11 @@ ŷ[1:show_first] ``` 5-element Vector{Tuple{Float64, Float64}}: - (-0.40997718991694765, 1.449009293726001) - (0.8484810430118421, 2.7074675266547907) - (0.547852151594671, 2.4068386352376194) - (-0.022697652913589494, 1.8362888307293592) - (0.07435130847990101, 1.9333377921228496) + (0.0458889297242715, 1.9182762960257687) + (-1.9174452847238976, -0.04505791842240037) + (-1.2544275358451678, 0.6179598304563294) + (-0.2818835218505735, 1.5905038444509236) + (0.01299565032151917, 1.8853830166230163) For simple models like this one, we can call a custom `Plots` recipe on our instance, fit result and data to generate the chart below: @@ -138,8 +138,6 @@ println("Empirical coverage: $(round(_eval.measurement[1], digits=3))") println("SSC: $(round(_eval.measurement[2], digits=3))") ``` - Started! - PerformanceEvaluation object with these fields: measure, operation, measurement, per_fold, per_observation, fitted_params_per_fold, @@ -148,11 +146,14 @@ println("SSC: $(round(_eval.measurement[2], digits=3))") ┌──────────────────────────────────────────────┬───────────┬─────────────┬────── │ measure │ operation │ measurement │ 1.9 ⋯ ├──────────────────────────────────────────────┼───────────┼─────────────┼────── - │ ConformalPrediction.emp_coverage │ predict │ 0.945 │ 0.0 ⋯ - │ ConformalPrediction.size_stratified_coverage │ predict │ 0.945 │ 0.0 ⋯ + │ ConformalPrediction.emp_coverage │ predict │ 0.948 │ 0.0 ⋯ + │ ConformalPrediction.size_stratified_coverage │ predict │ 0.948 │ 0.0 ⋯ └──────────────────────────────────────────────┴───────────┴─────────────┴────── 2 columns omitted + Empirical coverage: 0.948 + SSC: 0.948 + ## 📚 Read on If after reading the usage example above you are just left with more questions about the topic, that’s normal. Below we have have collected a number of further resources to help you get started with this package and the topic itself: @@ -231,7 +232,7 @@ There is also a simple `Plots.jl` recipe that can be used to inspect the set siz bar(mach.model, mach.fitresult, X) ``` -![](README_files/figure-commonmark/cell-11-output-1.svg) +![](README_files/figure-commonmark/cell-12-output-1.svg) ## 🛠 Contribute diff --git a/README.qmd b/README.qmd index feac2e7..2f3e6b8 100644 --- a/README.qmd +++ b/README.qmd @@ -22,7 +22,7 @@ jupyter: julia-1.9 [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) -[![Twitter Badge](https://img.shields.io/twitter/url/https/twitter.com/paltmey.svg?style=social&label=Follow%20%40paltmey)](https://twitter.com/paltmey) + [![License](https://img.shields.io/github/license/juliatrustworthyai/ConformalPrediction.jl)](LICENSE) + [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/ConformalPrediction/)](https://pkgs.genieframework.com?packages=ConformalPrediction) {{< include docs/src/_intro.qmd >}} \ No newline at end of file diff --git a/README_files/figure-commonmark/cell-12-output-1.svg b/README_files/figure-commonmark/cell-12-output-1.svg new file mode 100644 index 0000000..0ea75b9 --- /dev/null +++ b/README_files/figure-commonmark/cell-12-output-1.svg @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/README_files/figure-commonmark/cell-7-output-1.svg b/README_files/figure-commonmark/cell-7-output-1.svg index 1db8c0b..b253af4 100644 --- a/README_files/figure-commonmark/cell-7-output-1.svg +++ b/README_files/figure-commonmark/cell-7-output-1.svg @@ -1,294 +1,288 @@ - + - + - + - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/conformal_models/ConformalTraining/inductive_classification.jl b/src/conformal_models/ConformalTraining/inductive_classification.jl index 8a76eee..4f4e72d 100644 --- a/src/conformal_models/ConformalTraining/inductive_classification.jl +++ b/src/conformal_models/ConformalTraining/inductive_classification.jl @@ -18,7 +18,7 @@ function ConformalPrediction.score( ) X = permutedims(matrix(X)) probas = permutedims(fitresult[1](X)) - scores = @.(conf_model.heuristic(probas)) + scores = @.(conf_model.heuristic(y, probas)) if isnothing(y) return scores else @@ -46,7 +46,7 @@ function ConformalPrediction.score( p -> mean(p; dims=ndims(p)) |> p -> MLUtils.unstack(p; dims=ndims(p))[1] |> p -> permutedims(p) - scores = @.(conf_model.heuristic(probas)) + scores = @.(conf_model.heuristic(y, probas)) if isnothing(y) return scores else diff --git a/src/conformal_models/conformal_models.jl b/src/conformal_models/conformal_models.jl index 7c1068d..7899d07 100644 --- a/src/conformal_models/conformal_models.jl +++ b/src/conformal_models/conformal_models.jl @@ -16,6 +16,7 @@ const ConformalModel = Union{ } include("utils.jl") +include("heuristics.jl") include("plotting.jl") # Main API call to wrap model: diff --git a/src/conformal_models/heuristics.jl b/src/conformal_models/heuristics.jl new file mode 100644 index 0000000..035f3d2 --- /dev/null +++ b/src/conformal_models/heuristics.jl @@ -0,0 +1,13 @@ +""" + minus_softmax(y,ŷ) + +Computes `1.0 - ŷ` where `ŷ` is the softmax output for a given class. +""" +minus_softmax(y, ŷ) = 1.0 - ŷ + +""" + absolute_error(y,ŷ) + +Computes `abs(y - ŷ)` where `ŷ` is the predicted value. +""" +absolute_error(y, ŷ) = abs(y - ŷ) diff --git a/src/conformal_models/inductive_classification.jl b/src/conformal_models/inductive_classification.jl index 02f3752..9d6b82b 100644 --- a/src/conformal_models/inductive_classification.jl +++ b/src/conformal_models/inductive_classification.jl @@ -41,7 +41,7 @@ end function SimpleInductiveClassifier( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(p̂) = 1.0 - p̂, + heuristic::Function=minus_softmax, train_ratio::AbstractFloat=0.5, ) return SimpleInductiveClassifier(model, coverage, nothing, heuristic, train_ratio) @@ -62,7 +62,7 @@ function score( p̂ = reformat_mlj_prediction(MMI.predict(conf_model.model, fitresult, X)) L = p̂.decoder.classes probas = pdf(p̂, L) - scores = @.(conf_model.heuristic(probas)) + scores = @.(conf_model.heuristic(y, probas)) if isnothing(y) return scores else @@ -141,7 +141,7 @@ end function AdaptiveInductiveClassifier( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = 1.0 - ŷ, + heuristic::Function=minus_softmax, train_ratio::AbstractFloat=0.5, ) return AdaptiveInductiveClassifier(model, coverage, nothing, heuristic, train_ratio) diff --git a/src/conformal_models/inductive_regression.jl b/src/conformal_models/inductive_regression.jl index c2ce341..7a72037 100644 --- a/src/conformal_models/inductive_regression.jl +++ b/src/conformal_models/inductive_regression.jl @@ -10,7 +10,7 @@ end function SimpleInductiveRegressor( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, train_ratio::AbstractFloat=0.5, ) return SimpleInductiveRegressor(model, coverage, nothing, heuristic, train_ratio) diff --git a/src/conformal_models/transductive_classification.jl b/src/conformal_models/transductive_classification.jl index f188aca..350bef7 100644 --- a/src/conformal_models/transductive_classification.jl +++ b/src/conformal_models/transductive_classification.jl @@ -8,7 +8,7 @@ mutable struct NaiveClassifier{Model<:Supervised} <: ConformalProbabilisticSet end function NaiveClassifier( - model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=f(y, ŷ) = 1.0 - ŷ + model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=minus_softmax ) return NaiveClassifier(model, coverage, nothing, heuristic) end diff --git a/src/conformal_models/transductive_regression.jl b/src/conformal_models/transductive_regression.jl index 9ee2313..864604e 100644 --- a/src/conformal_models/transductive_regression.jl +++ b/src/conformal_models/transductive_regression.jl @@ -13,9 +13,7 @@ mutable struct NaiveRegressor{Model<:Supervised} <: ConformalInterval end function NaiveRegressor( - model::Supervised; - coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error ) return NaiveRegressor(model, coverage, nothing, heuristic) end @@ -81,9 +79,7 @@ mutable struct JackknifeRegressor{Model<:Supervised} <: ConformalInterval end function JackknifeRegressor( - model::Supervised; - coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error ) return JackknifeRegressor(model, coverage, nothing, heuristic) end @@ -163,9 +159,7 @@ mutable struct JackknifePlusRegressor{Model<:Supervised} <: ConformalInterval end function JackknifePlusRegressor( - model::Supervised; - coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error ) return JackknifePlusRegressor(model, coverage, nothing, heuristic) end @@ -254,9 +248,7 @@ mutable struct JackknifeMinMaxRegressor{Model<:Supervised} <: ConformalInterval end function JackknifeMinMaxRegressor( - model::Supervised; - coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + model::Supervised; coverage::AbstractFloat=0.95, heuristic::Function=absolute_error ) return JackknifeMinMaxRegressor(model, coverage, nothing, heuristic) end @@ -347,7 +339,7 @@ end function CVPlusRegressor( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, cv::MLJBase.CV=MLJBase.CV(), ) return CVPlusRegressor(model, coverage, nothing, heuristic, cv) @@ -452,7 +444,7 @@ end function CVMinMaxRegressor( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, cv::MLJBase.CV=MLJBase.CV(), ) return CVMinMaxRegressor(model, coverage, nothing, heuristic, cv) @@ -580,7 +572,7 @@ end function JackknifePlusAbRegressor( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, nsampling::Int=30, sample_size::AbstractFloat=0.5, replacement::Bool=true, @@ -686,7 +678,7 @@ end function JackknifePlusAbMinMaxRegressor( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, nsampling::Int=30, sample_size::AbstractFloat=0.5, replacement::Bool=true, @@ -789,7 +781,7 @@ end function TimeSeriesRegressorEnsembleBatch( model::Supervised; coverage::AbstractFloat=0.95, - heuristic::Function=f(y, ŷ) = abs(y - ŷ), + heuristic::Function=absolute_error, nsampling::Int=50, sample_size::AbstractFloat=0.3, aggregate::Union{Symbol,String}="mean",