This code provides an implementation of Calibrated Selective Classification.
(Some elements still in progress.)
Selective classification allows models to abstain from making predictions (e.g., say "I don't know") when in doubt in order to obtain better effective accuracy. While typical selective models can succeed at producing more accurate predictions on average, they may still allow for wrong predictions that have high confidence, or skip correct predictions that have low confidence. Providing calibrated uncertainty estimates alongside predictions---probabilities that correspond to true frequencies---can be as important as having predictions that are simply accurate on average. Uncertainty estimates, however, can sometimes be unreliable. This repository implements a new approach to calibrated selective classification (CSC), where a selector is learned to reject examples with "uncertain" uncertainties. The goal is to make predictions with well-calibrated uncertainty estimates over the distribution of accepted examples, a property called selective calibration.
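As a toy illustration (the numbers and the simple confidence-threshold selector below are made up for illustration and are not part of this repository), selecting which examples to predict on trades coverage for accuracy on the accepted subset:

```python
import numpy as np

# Toy illustration only: accept examples whose confidence exceeds a threshold,
# then measure coverage and accuracy on the accepted subset.
rng = np.random.default_rng(0)
confidences = rng.uniform(0.5, 1.0, size=1000)   # model confidences, e.g. max_y p(y|x)
correct = rng.uniform(size=1000) < confidences   # simulated correctness of each prediction

accept = confidences >= 0.8                      # simple confidence-based selector
coverage = accept.mean()                         # fraction of examples kept
selective_accuracy = correct[accept].mean()      # accuracy over accepted examples only

print(f"coverage={coverage:.2f}, selective accuracy={selective_accuracy:.2f}")
# Selective *calibration* additionally asks that, over the accepted examples,
# the reported confidences match the observed accuracy.
```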
We provide instructions to install CSC within a conda environment running Python 3.8.

```
conda create -n csc python=3.8 && conda activate csc
```

Install PyTorch following the instructions from pytorch.org.

Install all other requirements via `pip install -r requirements.txt`.

All commands should now be run from the root directory of this repository.
All pre-trained models and processed data can be downloaded using the provided script:

```
./download.sh cifar
./download.sh imagenet
```
See individual repositories for instructions on task-specific training, data downloading, and preprocessing.
**Note:** The following steps assume that training and testing data have already been preprocessed into `InputDataset`s, and that the calibration and validation data have been processed into `BatchedInputDataset`s. This format is described below.
An `InputDataset` is a namedtuple with the following fields:

- `input_features`: The representation for the input $x$. For example, the last-layer representation of $f(x)$ derived for all inputs $x$ in the dataset. This is an array of size `[num_examples, num_features]`.
- `output_probs`: The prediction $p_\theta(y|x)$. This is an array of size `[num_examples, num_classes]` (for binary problems, take `num_classes` = 2).
- `confidences`: The confidence estimate (typically $p_\theta(y = 1 | x)$ for binary problems, or $\max p_\theta(y | x)$ for multi-class problems). This is an array of size `[num_examples]`.
- `labels`: The binary (or binarized) label (typically $y$ for binary problems, or $y = \arg\max p_\theta(y | x)$ for multi-class problems). This is an array of size `[num_examples]`.
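For reference, a minimal sketch of assembling such a dataset from a model's outputs with NumPy (the `InputDataset` type itself is defined by this repository; the construction below is illustrative only):

```python
import collections
import numpy as np

# Illustrative only: a namedtuple with the four fields described above.
# The repository defines its own InputDataset; the module path may differ.
InputDataset = collections.namedtuple(
    "InputDataset", ["input_features", "output_probs", "confidences", "labels"])

num_examples, num_features, num_classes = 1000, 128, 10
features = np.random.randn(num_examples, num_features)            # last-layer representations f(x)
probs = np.random.dirichlet(np.ones(num_classes), num_examples)   # p_theta(y|x)
preds = probs.argmax(axis=-1)
true_labels = np.random.randint(num_classes, size=num_examples)

dataset = InputDataset(
    input_features=features,                     # [num_examples, num_features]
    output_probs=probs,                          # [num_examples, num_classes]
    confidences=probs.max(axis=-1),              # [num_examples], max_y p_theta(y|x)
    labels=(preds == true_labels).astype(int),   # [num_examples], 1 if the top prediction is correct
)
```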
A `BatchedInputDataset` is exactly the same as the above `InputDataset`, with the difference that each portion of the data has an extra leading dimension for the perturbation index. Specifically, the calibration and validation data are comprised of perturbed batches of data, in which a perturbation is applied to every example within a batch. A `BatchedInputDataset` is therefore essentially a concatenation of many `InputDataset`s. The sizes of each field of a `BatchedInputDataset` are:

- `input_features`: `[num_perturbations, num_examples_per_perturbation, num_features]`.
- `output_probs`: `[num_perturbations, num_examples_per_perturbation, num_classes]`.
- `confidences`: `[num_perturbations, num_examples_per_perturbation]`.
- `labels`: `[num_perturbations, num_examples_per_perturbation]`.
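Conceptually, a `BatchedInputDataset` can be built by stacking per-perturbation `InputDataset`s along a new leading axis, e.g. (illustrative sketch, not the repository's loading code):

```python
import collections
import numpy as np

# Illustrative only: the repository defines its own BatchedInputDataset with the
# same four fields; here we just show the expected shapes.
BatchedInputDataset = collections.namedtuple(
    "BatchedInputDataset", ["input_features", "output_probs", "confidences", "labels"])

def batch_datasets(datasets):
    """Stack per-perturbation InputDatasets (same size each) along a new leading axis."""
    return BatchedInputDataset(
        input_features=np.stack([d.input_features for d in datasets]),  # [K, N, num_features]
        output_probs=np.stack([d.output_probs for d in datasets]),      # [K, N, num_classes]
        confidences=np.stack([d.confidences for d in datasets]),        # [K, N]
        labels=np.stack([d.labels for d in datasets]),                  # [K, N]
    )
```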
(Optional) To check the initial calibration error without any selection, you can run:
```
python bin/tools/check_calibration_error.py \
  --datasets <paths to saved InputDatasets and/or BatchedInputDatasets>
```
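For intuition, the quantity being checked is a binned calibration error comparing confidences to empirical accuracy; a simplified (equal-width-bin) version looks like the sketch below, though the script's exact estimator and options may differ:

```python
import numpy as np

def expected_calibration_error(confidences, labels, num_bins=15):
    """Binned ECE: average |accuracy - confidence| per bin, weighted by bin mass."""
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = labels[in_bin].mean()         # empirical accuracy in the bin
            conf = confidences[in_bin].mean()   # mean confidence in the bin
            ece += in_bin.mean() * abs(acc - conf)
    return ece
```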
To generate the meta-features, run:

```
python bin/tools/generate_meta_features.py \
  --train-dataset <path to saved InputDataset> \
  --cal-dataset <path to saved BatchedInputDataset> \
  --val-dataset <path to saved BatchedInputDataset> \
  --test-datasets <paths to saved InputDatasets>
```
This will save each calibration, validation, and testing file as a new `InputDataset` (or `BatchedInputDataset`), with the `input_features` field replaced with the derived meta-features.
**Note:** See Section 4.5 of the paper for a description of the chosen meta-features.
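For intuition only, generic per-example meta-features over the model's outputs and representations might look like the sketch below; the actual set used here is the one described in the paper, not these placeholders:

```python
import numpy as np

# Illustrative meta-features only -- see Section 4.5 of the paper for the real set.
def example_meta_features(train_features, input_features, output_probs, k=10):
    confidence = output_probs.max(axis=-1)                             # max_y p(y|x)
    entropy = -(output_probs * np.log(output_probs + 1e-12)).sum(-1)   # predictive entropy
    # Distance-based outlier signal: mean distance to the k nearest training representations.
    dists = np.linalg.norm(input_features[:, None, :] - train_features[None, :, :], axis=-1)
    knn_dist = np.sort(dists, axis=-1)[:, :k].mean(axis=-1)
    return np.stack([confidence, entropy, knn_dist], axis=-1)          # [num_examples, 3]
```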
To train a soft selective model, run:

```
python bin/tools/train_selective_model.py \
  --cal-dataset <path to saved BatchedInputDataset> \
  --val-dataset <path to saved BatchedInputDataset>
```

The binary predictor $g$ is trained over the meta-features from the calibration and validation data. Pass `--model-dir` to specify a target directory to save the model (otherwise a temp directory will be used automatically).
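Conceptually, the saved model is a small soft binary selector $g$ over the meta-features that outputs an acceptance score in $[0, 1]$; the PyTorch sketch below is illustrative only, and the repository's actual architecture and training objective may differ:

```python
import torch
import torch.nn as nn

# Illustrative only: a small MLP selector g over meta-features producing a soft
# acceptance score in [0, 1]. Not the repository's actual model definition.
class Selector(nn.Module):
    def __init__(self, num_meta_features, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_meta_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, meta_features):
        # g(x) in [0, 1]: higher means "accept this example".
        return torch.sigmoid(self.net(meta_features)).squeeze(-1)
```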
**Note:** All of the subsequent evaluation steps simultaneously calibrate and evaluate the selector `g(X)` such that it achieves the target coverage. To derive a threshold for the soft selector, run:

```
python bin/tools/calibrate_selector_threshold.py \
  --model-file <path to saved SelectiveNet checkpoint> \
  --calibration-dataset <path to saved (unlabeled) InputDataset>
```
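Deriving the threshold amounts to picking the score cutoff that accepts roughly the desired fraction of the (unlabeled) calibration examples, i.e., a quantile of the selector scores; a sketch under that assumption:

```python
import numpy as np

def derive_threshold(selector_scores, coverage):
    """Pick tau so that roughly a `coverage` fraction of examples have g(x) >= tau."""
    return np.quantile(selector_scores, 1.0 - coverage)

# Example: accept ~90% of examples.
# tau = derive_threshold(g_scores_on_calibration_data, coverage=0.9)
```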
To make predictions (at a target coverage level $\xi$), run:

```
python bin/tools/run_selective_model.py \
  --model-file <path to saved SelectiveNet checkpoint> \
  --input-dataset <path to saved InputDataset> \
  --calibration-dataset <path to saved (unlabeled) InputDataset> \
  --threshold <threshold for g> \
  --coverage <coverage level \xi> \
  --output-file <path to output file>
```
If the coverage level $\xi$ is given, the threshold for the soft selector is derived from the provided calibration dataset so that the target coverage is (approximately) achieved.
**Warning:** If the `calibration_dataset` argument is not given, then the threshold for making predictions with the soft selector will be computed on the `input_dataset`. If, however, the `threshold` argument is given, then both the `calibration_dataset` and `coverage_level` arguments will be ignored, and the model will make predictions using the given threshold.
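In code form, the precedence described above is roughly (a sketch, not the script's literal logic):

```python
import numpy as np

# Sketch of the precedence described in the warning above (not the script's code).
def resolve_threshold(g_scores_input, g_scores_calibration, coverage, threshold=None):
    if threshold is not None:
        return threshold                         # explicit threshold wins; coverage/calibration ignored
    scores = g_scores_calibration if g_scores_calibration is not None else g_scores_input
    return np.quantile(scores, 1.0 - coverage)   # otherwise derive the threshold from available scores
```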
To evaluate the selective predictor in terms of selective calibration error, run:
```
python bin/tools/evaluate_selective_model.py \
  --model-files <paths to saved SelectiveNet checkpoints> \
  --coverage <coverage level \xi> \
  --output-file <path to output file>
```
If the coverage level $\xi$ is given, evaluation is repeated over multiple bootstrap resamples of the data (controlled by the `--bootstraps` argument). Note that this will compute all predictions from scratch.
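Selective calibration error is measured only over the examples the selector accepts; reusing the binned-error sketch from above, a simplified version is shown below (the script's exact metric may differ, e.g. an $\ell_2$ variant):

```python
import numpy as np

def selective_calibration_error(confidences, labels, accept_mask, num_bins=15):
    """Binned calibration error restricted to examples the selector accepts."""
    kept_conf = confidences[accept_mask]
    kept_labels = labels[accept_mask]
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    error = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (kept_conf > lo) & (kept_conf <= hi)
        if in_bin.any():
            gap = abs(kept_labels[in_bin].mean() - kept_conf[in_bin].mean())
            error += in_bin.mean() * gap
    return error
```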
We provide some Jupyter notebook scripts for visualizing predictions in the notebooks folder.
If you find this work useful, please cite our TMLR paper:
```
@article{fisch2022selective,
    title={Calibrated Selective Classification},
    author={Adam Fisch and Tommi Jaakkola and Regina Barzilay},
    journal={Transactions on Machine Learning Research (TMLR)},
    month={12},
    year={2022},
    url={https://openreview.net/forum?id=zFhNBs8GaV},
}
```