
ConformalPrediction


ConformalPrediction.jl is a package for Uncertainty Quantification (UQ) through Conformal Prediction (CP) in Julia. It is designed to work with supervised models trained in MLJ. Conformal Prediction is distribution-free, easy-to-understand, easy-to-use and model-agnostic.

Installation 🚩

You can install the first stable release from the general registry:

using Pkg
Pkg.add("ConformalPrediction")

The development version can be installed as follows:

using Pkg
Pkg.add(url="https://github.com/pat-alt/ConformalPrediction.jl")

Status 🔁

This package is in the very early stages of development, and the core architecture is therefore still subject to change. The following approaches have been implemented in the development version:

Regression:

  • Inductive
  • Naive Transductive
  • Jackknife
  • Jackknife+
  • Jackknife-minmax
  • CV+
  • CV-minmax

Classification:

  • Inductive (LABEL; Sadinle, Lei, and Wasserman 2019)
  • Adaptive Inductive

So far, I have only tested the package with a few of the supervised models offered by MLJ. The sketch below illustrates how a specific conformal approach might be selected.
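
Note that the following is only a hypothetical sketch: the method keyword and the symbol name used here are assumptions rather than confirmed API, so please consult the conformal_model docstring for the options your version actually exposes.

using MLJ
using ConformalPrediction

# Hypothetical: select a particular conformal approach by keyword.
# The method keyword and the :jackknife_plus symbol are assumptions;
# check the conformal_model docstring for the supported options.
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
conf_model = conformal_model(DecisionTreeRegressor(); method=:jackknife_plus)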

Usage Example 🔍

To illustrate the intended use of the package, let's have a quick look at a simple regression problem. Using MLJ we first generate some synthetic data and then determine indices for our training and test data (for the inductive approach used below, the calibration split is handled internally by the conformal model):

using MLJ
X, y = MLJ.make_regression(1000, 2)
train, test = partition(eachindex(y), 0.4, 0.4)

We then import a decision tree regressor (DecisionTreeRegressor from the DecisionTree.jl package) following the standard MLJ procedure.

DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor() 

To turn our conventional model into a conformal model, we just need to declare it as such using the conformal_model wrapper function. The generated conformal model instance can then be wrapped in data to create a machine. Finally, we fit the machine on the training data using the generic fit! method:

using ConformalPrediction
conf_model = conformal_model(model)
mach = machine(conf_model, X, y)
fit!(mach, rows=train)

Predictions can then be computed using the generic predict method. The code below produces predictions for the first n samples of the test set. Each tuple contains the lower and upper bound of the prediction interval.

n = 10
Xtest = selectrows(X, first(test,n))
ytest = y[first(test,n)]
predict(mach, Xtest)
╭────────────────────────────────────────────────────────────────╮
│                                                                │
│       (1)   ([-1.755717205142032], [0.1336793920749545])       │
│       (2)   ([-2.725152276022311], [-0.8357556788053242])      │
│       (3)   ([1.7996228430066177], [3.6890194402236043])       │
│       (4)   ([-2.090812733251826], [-0.20141613603483965])     │
│       (5)   ([0.9599243814807339], [2.8493209786977207])       │
│       (6)   ([-0.6383470472809984], [1.2510495499359882])      │
│       (7)   ([1.6779292744150438], [3.5673258716320304])       │
│       (8)   ([0.08317330201878925], [1.9725698992357759])      │
│       (9)   ([-0.12150563172572815], [1.7678909654912585])     │
│      (10)   ([-1.1611481858237893], [0.7282484113931974])      │
│                                                                │
│                                                                │
╰─────────────────────────────────────────────────── 10 items ───╯
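
As a quick, informal sanity check (not part of the original example), one might estimate the empirical coverage of these intervals over the full test split. The sketch below assumes the tuple-of-vectors return format shown above and pulls in mean from the Statistics standard library:

using Statistics: mean

# Rough empirical coverage estimate over the held-out test points.
# Each prediction is assumed to be a (lower, upper) tuple of one-element
# vectors, matching the output shown above.
predictions = predict(mach, selectrows(X, test))
covered = map(zip(predictions, y[test])) do (interval, yi)
    lower, upper = interval
    lower[1] <= yi <= upper[1]
end
println("Empirical coverage: ", round(mean(covered); digits=3))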

Contribute 🛠

Contributions are welcome! Please follow the SciML ColPrac guide.

References 🎓

Sadinle, Mauricio, Jing Lei, and Larry Wasserman. 2019. "Least Ambiguous Set-Valued Classifiers with Bounded Error Levels." Journal of the American Statistical Association 114 (525): 223–34.