90 average calibration functions in utils.jl #97
Closed
Changes from all commits (48 commits, all by pasq-cat):
733312a  function empirical_frequency
25ea642  fixed the docstring.
c290ed8  added sharpness and binary classification. i have yet to test them pr…
4ff22f4  added trapz to the list of dependencies.
6a22210  added Distributions to theproject
df3d60d  working version
09f25e8  ops forgot to add sharpness for the classification case
07b318f  working release.. changed changelog, glm_predictive_distribution, pr…
eafa7bd  function empirical_frequency
f66e08e  fixed the docstring.
5355281  added sharpness and binary classification. i have yet to test them pr…
2efaa99  added trapz to the list of dependencies.
26643ee  added Distributions to theproject
b79ca39  working version
0d71736  ops forgot to add sharpness for the classification case
5f772cf  working release.. changed changelog, glm_predictive_distribution, pr…
d146d1d  Merge branch '90-average-calibration-in-utilsjl' of https://github.co…
7af9378  changed docstrings in predicting.jl
2c42236  fixed glm_predictive_distribution
9d67ddc  Update src/utils.jl
9f07583  Update src/utils.jl
f81d226  Update src/utils.jl
6cdc503  Update src/baselaplace/predicting.jl
89bb19b  Update src/baselaplace/predicting.jl
6fe01a2  JuliaFormatter
0bba488  fixed docstrings
8311de3  made docstrings a lil bit shorter
7837333  docstrings again (added output)
b0518b2  fixed binary classification case, exported function from utils.
6a9ee1b  juliaformatter
203513d  add n_bins as argument to functions
dce9bdb  ops forgot default value
b906c3b  ops forgot default value and removed a line
2059bed  Merge branch '90-average-calibration-in-utilsjl' of https://github.co…
3258618  juliaformatter----
c86dc25  fixed small error in pred_avg
3d2ebd6  fixed error in empirical_frequency_regression
4ab04f6  Update src/utils.jl
267b8f4  docstrings fixes and predict update
d188daf  fixed typos
270b70a  moved sharpness functions units tests in calibration.jl. changed run…
3320063  more sharpness unit tests
3750dbe  fixes and more unit tests
39d4bdc  small stuff
56c3b66  fix. there is still an issue with the shape of the input to use.
908c804  fixed logit.md ,moved functions to new file, removed changes to predi…
f468803  removed calibration_plots.md
459b2fe  test plot
@@ -0,0 +1,9 @@
# Uncertainty Calibration

## The issue of calibrated uncertainty distributions

Bayesian methods offer a general framework for quantifying uncertainty. However, due to model misspecification and the use of approximate inference, Bayesian uncertainty estimates are often inaccurate: for example, a 90% credible interval may not contain the true outcome 90% of the time. A model is considered calibrated when its uncertainty estimates, such as Bayesian credible intervals, accurately reflect the true likelihood of outcomes; that is, a 90% credible interval is calibrated if it contains the true outcome approximately 90% of the time, indicating the reliability of the inference method. In short, a good forecaster must be calibrated. Perfect calibration is reached when the empirical frequency of outcomes matches the predicted confidence level at every level.

## Calibration Plots

yadda yadda
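To make this concrete, here is a minimal sketch (an editor illustration with made-up data, not part of the package) of how miscalibration shows up: an overconfident Gaussian forecaster whose nominal 90% intervals cover far fewer than 90% of the outcomes.

```julia
using Distributions

# Each outcome has unit Gaussian noise, but the forecaster claims std 0.5.
mu = randn(1000)
y_true = mu .+ randn(1000)
predicted = [Normal(m, 0.5) for m in mu]

# Empirical coverage of the central 90% credible interval.
covered = count(zip(y_true, predicted)) do (y, d)
    quantile(d, 0.05) <= y <= quantile(d, 0.95)
end
println("90% interval empirical coverage: ", covered / length(y_true))  # ≈ 0.59, not 0.90
```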
docs/src/tutorials/regression_files/figure-commonmark/miscalibration.svg
New file, 56 additions (SVG figure; not displayable in the diff view).
@@ -0,0 +1,141 @@
using Statistics

@doc raw"""
    empirical_frequency_regression(Y_cal, sampled_distributions; n_bins::Int=20)

FOR REGRESSION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}`` and an array of predicted distributions, the function calculates the empirical frequency
```math
\hat{p}_j = \frac{\left|\{y_t \mid F_t(y_t) \leq p_j,\ t = 1, \dots, T\}\right|}{T},
```
where ``T`` is the number of calibration points, ``p_j`` is the confidence level and ``F_t`` is the
cumulative distribution function of the predicted distribution targeting ``y_t``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `Y_cal`: a vector of values ``y_t`` \
- `sampled_distributions`: a `Vector{Vector{Float64}}` of sampled distributions ``F(x_t)`` stacked row-wise. \
    For example, `[rand(distr, 50) for distr in LaplaceRedux.predict(la, X)]` \
- `n_bins`: the number of equally spaced bins to use. \

Outputs: \
- `counts`: an array containing the empirical frequencies for each quantile interval.
"""
function empirical_frequency_regression(Y_cal, sampled_distributions; n_bins::Int=20)
    if n_bins <= 0
        throw(ArgumentError("n_bins must be a positive integer"))
    end
    n_edges = n_bins + 1
    quantiles = collect(range(0; stop=1, length=n_edges))
    # One column per calibration point, one row per confidence level.
    quantiles_matrix = hcat(
        [quantile(samples, quantiles) for samples in sampled_distributions]...
    )
    n_rows = size(quantiles_matrix, 1)
    counts = Float64[]

    # Fraction of targets that fall below the predicted quantile at each level.
    for i in 1:n_rows
        push!(counts, sum(Y_cal .<= quantiles_matrix[i, :]) / length(Y_cal))
    end
    return counts
end
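
# Illustrative usage sketch (editor example with assumed data, not part of this
# file): a forecaster that knows the true noise level should yield empirical
# frequencies close to the nominal confidence levels.
using Distributions
mu = randn(100)                                     # hypothetical predicted means
Y_cal = mu .+ randn(100)                            # targets with unit Gaussian noise
sampled = [rand(Normal(m, 1.0), 1_000) for m in mu]
counts = empirical_frequency_regression(Y_cal, sampled; n_bins=10)
# For a calibrated forecaster, counts roughly tracks range(0, 1; length=11).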

@doc raw"""
    sharpness_regression(sampled_distributions)

FOR REGRESSION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}`` and an array of predicted distributions, the function calculates the
sharpness of the predicted distributions, i.e., the average of the variances ``\sigma^2(F_t)`` predicted by the forecaster for each ``x_t``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `sampled_distributions`: a `Vector{Vector{Float64}}` of sampled distributions ``F(x_t)``, one per calibration point. \

Outputs: \
- `sharpness`: a scalar measuring the sharpness of the regressor (lower mean variance means sharper forecasts).
"""
function sharpness_regression(sampled_distributions)
    sharpness = mean(var.(sampled_distributions))
    return sharpness
end
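
# Illustrative usage sketch (editor example with assumed data, not part of this
# file): sharpness is the mean predictive variance, so tighter forecasts score lower.
using Distributions
wide   = [rand(Normal(0.0, 2.0), 1_000) for _ in 1:50]
narrow = [rand(Normal(0.0, 0.5), 1_000) for _ in 1:50]
sharpness_regression(narrow) < sharpness_regression(wide)   # true (≈ 0.25 vs ≈ 4.0)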

@doc raw"""
    empirical_frequency_binary_classification(y_binary, sampled_distributions; n_bins::Int=20)

FOR BINARY CLASSIFICATION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}``, let ``p_t = H(x_t) ∈ [0, 1]`` be the forecasted probability. \
We group the ``p_t`` into intervals ``I_j`` for ``j = 1, 2, ..., m`` that form a partition of ``[0, 1]``.
The function computes the observed average ``p_j = T_j^{-1} ∑_{t: p_t ∈ I_j} y_t`` in each interval ``I_j``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `y_binary`: the array of outputs ``y_t``, numerically coded: 1 for the target class, 0 for the null class. \
- `sampled_distributions`: an array of sampled distributions stacked column-wise, so that the first row
    holds the probability for the target class ``y_1`` and the second row the probability for the null class ``y_0``. \
- `n_bins`: the number of equally spaced bins to use.

Outputs: \
- `num_p_per_interval`: array with the number of probabilities falling within each interval. \
- `emp_avg`: array with the observed empirical average per interval. \
- `bin_centers`: array with the centers of the bins.
"""
function empirical_frequency_binary_classification(
    y_binary, sampled_distributions; n_bins::Int=20
)
    if n_bins <= 0
        throw(ArgumentError("n_bins must be a positive integer"))
    elseif !all(x -> x == 0 || x == 1, y_binary)
        throw(ArgumentError("y_binary must be an array of 0 and 1"))
    end
    # interval boundaries
    n_edges = n_bins + 1
    int_bds = collect(range(0; stop=1, length=n_edges))
    # bin centers
    bin_centers = [(int_bds[i] + int_bds[i + 1]) / 2 for i in 1:(length(int_bds) - 1)]
    # initialize list of empirical averages per interval
    emp_avg = []
    # initialize list of predicted averages per interval
    pred_avg = []
    # initialize list of the number of probabilities falling within each interval
    num_p_per_interval = []
    # predicted probabilities for the target class
    class_probs = sampled_distributions[1, :]
    # iterate over the bins
    for j in 1:n_bins
        push!(num_p_per_interval, sum(int_bds[j] .< class_probs .< int_bds[j + 1]))
        if num_p_per_interval[j] == 0
            push!(emp_avg, 0)
            push!(pred_avg, bin_centers[j])
        else
            # find the indices of all instances for which class_probs falls within the j-th interval
            indices = findall(x -> int_bds[j] < x < int_bds[j + 1], class_probs)
            # compute the empirical average and save it in emp_avg at the j-th position
            push!(emp_avg, 1 / num_p_per_interval[j] * sum(y_binary[indices]))
            # TODO: maybe substitute with bin_centers?
            push!(pred_avg, 1 / num_p_per_interval[j] * sum(class_probs[indices]))
        end
    end
    # return the tuple
    return (num_p_per_interval, emp_avg, bin_centers)
end
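
# Illustrative usage sketch (editor example with assumed data, not part of this
# file): probabilities stacked as a 2×N matrix, first row for the target class.
p = rand(200)                                  # hypothetical target-class probabilities
sampled = permutedims(hcat(p, 1 .- p))         # 2×200: row 1 target, row 2 null class
y = [rand() < prob ? 1 : 0 for prob in p]      # labels drawn from those probabilities
num, avg, centers = empirical_frequency_binary_classification(y, sampled; n_bins=10)
# Since the labels follow the probabilities exactly, avg ≈ centers in populated bins.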

@doc raw"""
    sharpness_classification(y_binary, sampled_distributions)

FOR BINARY CLASSIFICATION MODELS. \
Assess the sharpness of the model by looking at the distribution of model predictions.
When forecasts are sharp, most predictions are close to either 0 or 1. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `y_binary`: the array of outputs ``y_t``, numerically coded: 1 for the target class, 0 for the null class. \
- `sampled_distributions`: an array of sampled distributions stacked column-wise, so that the first row holds the probability for the target class ``y_1`` and the second row the probability for the null class ``y_0``. \

Outputs: \
- `mean_class_one`: a scalar measuring the average prediction for the target class \
- `mean_class_zero`: a scalar measuring the average prediction for the null class
"""
function sharpness_classification(y_binary, sampled_distributions)
    # average predicted probability of each class over the instances that belong to it
    mean_class_one = mean(sampled_distributions[1, findall(y_binary .== 1)])
    mean_class_zero = mean(sampled_distributions[2, findall(y_binary .== 0)])
    return mean_class_one, mean_class_zero
end
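
# Illustrative usage sketch (editor example with assumed data, not part of this
# file): a sharp forecaster pushes probabilities toward 0 or 1, so both class
# means end up close to 1.
labels = rand(0:1, 100)
p_sharp = [l == 1 ? 0.95 : 0.05 for l in labels]     # hypothetical confident forecaster
sampled = permutedims(hcat(p_sharp, 1 .- p_sharp))   # row 1: target class, row 2: null class
sharpness_classification(labels, sampled)            # ≈ (0.95, 0.95)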
@@ -1,4 +1,5 @@
 using Flux
+using Statistics

 """
     get_loss_fun(likelihood::Symbol)
I guess we could just keep this consistent and return everything in both cases?
Edit: my bad; as discussed, let's just add an option for classification to return the distribution. By default, we should still return probabilities for now, but at least we give the option and add that to the docstring.