average calibration functions in utils.jl #97
@@ -1,3 +1,5 @@
+using Distributions: Distributions
+using Statistics: mean, var
 """
     functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)
@@ -22,6 +24,8 @@ Computes the linearized GLM predictive.
 - `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
 - `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.
+- `normal_distr`: An array of normal distributions approximating the predictive distribution p(y|X) given the input data X.

 # Examples

 ```julia-repl
@@ -39,7 +43,9 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
     fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
     fvar = functional_variance(la, 𝐉)
     fvar = reshape(fvar, size(fμ)...)
-    return fμ, fvar
+    fstd = sqrt.(fvar)
+    normal_distr = [Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)]
+    return (normal_distr, fμ, fvar)
 end
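With this change, `glm_predictive_distribution` returns a three-element tuple rather than `(fμ, fvar)`. A minimal sketch of the new call site (assuming a fitted Laplace object `la` and an input batch `X`, which are not part of this diff):

```julia
using Statistics: mean, var

# `la` and `X` are assumed given: a fitted Laplace approximation and input data.
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)

# Each element of `normal_distr` is a Distributions.Normal whose mean and
# variance agree with the corresponding entries of fμ and fvar.
@assert mean(normal_distr[1]) ≈ fμ[1]
@assert var(normal_distr[1]) ≈ fvar[1]
```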
""" | ||
|
@@ -55,9 +61,12 @@ Computes predictions from Bayesian neural network. | |
- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks. | ||
|
||
# Returns | ||
|
||
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux. | ||
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux. | ||
For classification tasks, LaplaceRedux provides different options: | ||
-`normal_distr::Distributions.Normal`:the array of Normal distributions computed by glm_predictive_distribution If the `link_approx` is set to :distribution | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @pat-alt i am confused on how to proceed here. the issue is that the output of the chain may have already passed through a softmax layer, so the output should not be converted again. should we add a check for this or leave it to the educated reader lol? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! I think we should probably add a check somewhere in the corestruct that contains the Flux chain. Let's open a separate issue for this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. either that or modify the predict function so that it checks that the la.model has a finaliser layer |
||
-`fμ::AbstractArray` Mean of the normal distribution if link_approx is set to :plugin | ||
-`fμ::AbstractArray` The probit approximation if link_approx is set to :probit | ||
For regression tasks: | ||
- `normal_distr::Distributions.Normal`:the array of Normal distributions computed by glm_predictive_distribution. | ||
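One possible shape for the finaliser check discussed in the thread; this is only a sketch, and `has_finaliser` is a hypothetical helper, not part of the package:

```julia
using Flux

# Hypothetical helper: detect whether a Chain already ends in a
# probability-producing finaliser, in which case `predict` should not
# apply sigmoid/softmax to the output again.
function has_finaliser(model::Chain)
    return model.layers[end] === Flux.softmax || model.layers[end] === Flux.sigmoid
end

has_finaliser(Chain(Dense(2, 2), softmax))  # true
has_finaliser(Chain(Dense(2, 2)))           # false
```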
@@ -75,16 +84,22 @@ predict(la, hcat(x...))
 function predict(
     la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
 )
-    fμ, fvar = glm_predictive_distribution(la, X)
+    normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
+    #fμ, fvar = mean.(normal_distr), var.(normal_distr)

     # Regression:
     if la.likelihood == :regression
-        return fμ, fvar
+        return normal_distr
     end

     # Classification:
     if la.likelihood == :classification

+        # Return the predictive distributions directly
+        if link_approx == :distribution
+            z = normal_distr
+        end
+
         # Probit approximation
         if link_approx == :probit
             z = probit(fμ, fvar)
@@ -95,7 +110,7 @@ function predict(
         end

         # Sigmoid/Softmax
-        if predict_proba
+        if (predict_proba && link_approx != :distribution)
             if la.posterior.n_out == 1
                 p = Flux.sigmoid(z)
             else
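A sketch of how the new `:distribution` option might be called once merged (assuming a fitted classifier `la` and inputs `X`; the default behaviour is unchanged):

```julia
# Default: probabilities via the probit approximation (unchanged behaviour).
p = predict(la, X)

# New option added in this PR: return the predictive distributions themselves.
dists = predict(la, X; link_approx=:distribution)
```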
@@ -1,4 +1,5 @@
 using Flux
+using Statistics

 """
     get_loss_fun(likelihood::Symbol)
@@ -39,3 +40,141 @@ corresponding to the number of neurons on the last layer of the NN.
 function outdim(model::Chain)::Number
     return [size(p) for p in Flux.params(model)][end][1]
 end
Review thread on the new functions below:
- "@pat-alt maybe I can move these functions to a dedicated Julia file (calibration_functions.jl), so that in the future I may add something else in a compartmentalized file."
- Reply: "Yes, good idea."

@doc raw"""
    empirical_frequency_regression(Y_cal, sampled_distributions, n_bins=20)

FOR REGRESSION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}`` and an array of predicted distributions, the function calculates the empirical frequency
```math
\hat{p}_j = \frac{\left|\{\, y_t \mid F_t(y_t) \leq p_j,\ t = 1, \dots, T \,\}\right|}{T},
```
where ``T`` is the number of calibration points, ``p_j`` is the confidence level and ``F_t`` is the
cumulative distribution function of the predicted distribution targeting ``y_t``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `Y_cal`: a vector of values ``y_t``. \
- `sampled_distributions`: an array of sampled distributions ``F(x_t)`` stacked column-wise. \
- `n_bins`: number of equally spaced bins to use. \

Outputs: \
- `counts`: an array containing the empirical frequencies for each quantile interval.
"""
function empirical_frequency_regression(Y_cal, sampled_distributions, n_bins::Int=20)
    if n_bins <= 0
        throw(ArgumentError("n_bins must be a positive integer"))
    end
    quantiles = collect(range(0; stop=1, length=n_bins + 1))
    quantiles_matrix = hcat(
        [quantile(samples, quantiles) for samples in sampled_distributions]...
    )
    n_rows = size(quantiles_matrix, 1)
    counts = []

    for i in 1:n_rows
        push!(counts, sum(Y_cal .<= quantiles_matrix[i, :]) / length(Y_cal))
    end
    return counts
end
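A hedged usage sketch with synthetic samples (all data and sizes below are illustrative, not from the PR):

```julia
using Statistics

# Three calibration targets, each paired with 1000 posterior samples
# centred on the target.
Y_cal = [0.1, 0.5, 0.9]
sampled_distributions = [0.1 .* randn(1000) .+ y for y in Y_cal]

# For a well-calibrated forecaster, counts[j] ≈ the j-th confidence level.
counts = empirical_frequency_regression(Y_cal, sampled_distributions, 10)
```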
@doc raw"""
    sharpness_regression(sampled_distributions)

FOR REGRESSION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}`` and an array of predicted distributions, the function calculates the
sharpness of the predicted distributions, i.e., the average of the variances ``\sigma^2(F_t)`` predicted by the forecaster for each ``x_t``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `sampled_distributions`: an array of sampled distributions ``F(x_t)`` stacked column-wise. \

Outputs: \
- `sharpness`: a scalar that measures the level of sharpness of the regressor.
"""
function sharpness_regression(sampled_distributions)
    sharpness = mean(var.(sampled_distributions))
    return sharpness
end
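Since sharpness is just the mean predictive variance, a quick sanity check (illustrative values only):

```julia
using Statistics: mean, var

# Sampled predictive distributions with variances ≈ 1.0, 4.0, 0.25.
samples = [randn(1000), 2 .* randn(1000), 0.5 .* randn(1000)]
sharpness_regression(samples)  # ≈ mean([1.0, 4.0, 0.25]) = 1.75; smaller is sharper
```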
@doc raw"""
    empirical_frequency_binary_classification(y_binary, sampled_distributions, n_bins=20)

FOR BINARY CLASSIFICATION MODELS. \
Given a calibration dataset ``(x_t, y_t)`` for ``t ∈ \{1, ..., T\}``, let ``p_t = H(x_t) ∈ [0, 1]`` be the forecasted probability. \
We group the ``p_t`` into intervals ``I_j`` for ``j = 1, 2, ..., m`` that form a partition of ``[0, 1]``.
The function computes the observed average ``p_j = T_j^{-1} \sum_{t : p_t ∈ I_j} y_t`` in each interval ``I_j``. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `y_binary`: the array of outputs ``y_t`` numerically coded: 1 for the target class, 0 for the null class. \
- `sampled_distributions`: an array of sampled distributions stacked column-wise so that in the first row
  there is the probability for the target class ``y_1`` and in the second row the probability for the null class ``y_0``. \
- `n_bins`: number of equally spaced bins to use.

Outputs: \
- `num_p_per_interval`: array with the number of probabilities falling within each interval. \
- `emp_avg`: array with the observed empirical average per interval. \
- `bin_centers`: array with the centers of the bins.
"""
function empirical_frequency_binary_classification(
    y_binary, sampled_distributions, n_bins::Int=20
)
    if n_bins <= 0
        throw(ArgumentError("n_bins must be a positive integer"))
    elseif !all(x -> x == 0 || x == 1, y_binary)
        throw(ArgumentError("y_binary must be an array of 0 and 1"))
    end
    # interval boundaries
    int_bds = collect(range(0; stop=1, length=n_bins + 1))
    # bin centers
    bin_centers = [(int_bds[i] + int_bds[i + 1]) / 2 for i in 1:(length(int_bds) - 1)]
    # initialize list for empirical averages per interval
    emp_avg = []
    # initialize list for predicted averages per interval
    pred_avg = []
    # initialize list of number of probabilities falling within each interval
    num_p_per_interval = []
    # list of the predicted probabilities for the target class
    class_probs = sampled_distributions[1, :]
    # iterate over the bins
    for j in 1:n_bins
        push!(num_p_per_interval, sum(int_bds[j] .< class_probs .< int_bds[j + 1]))
        if num_p_per_interval[j] == 0
            push!(emp_avg, 0)
            push!(pred_avg, bin_centers[j])
        else
            # find the indices of all instances for which class_probs falls within the j-th interval
            indices = findall(x -> int_bds[j] < x < int_bds[j + 1], class_probs)
            # compute the empirical average and save it in emp_avg at the j-th position
            push!(emp_avg, 1 / num_p_per_interval[j] * sum(y_binary[indices]))
            # TODO: maybe substitute with bin_centers?
            push!(pred_avg, 1 / num_p_per_interval[j] * sum(class_probs[indices]))
        end
    end
    return (num_p_per_interval, emp_avg, bin_centers)
end
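A usage sketch for a reliability-diagram style check (all values below are illustrative):

```julia
# First row: predicted probability of the target class;
# second row: probability of the null class.
probs_class1 = [0.1, 0.8, 0.65, 0.3, 0.9, 0.75]
sampled_distributions = vcat(probs_class1', (1 .- probs_class1)')
y_binary = [0, 1, 1, 0, 1, 1]

num_p, emp_avg, centers = empirical_frequency_binary_classification(
    y_binary, sampled_distributions, 5
)
# Plotting emp_avg against centers gives the reliability (calibration) curve.
```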
@doc raw"""
    sharpness_classification(y_binary, sampled_distributions)

FOR BINARY CLASSIFICATION MODELS. \
Assess the sharpness of the model by looking at the distribution of model predictions.
When forecasts are sharp, most predictions are close to either 0 or 1. \
Source: [Kuleshov, Fenner, Ermon 2018](https://arxiv.org/abs/1807.00263)

Inputs: \
- `y_binary`: the array of outputs ``y_t`` numerically coded: 1 for the target class, 0 for the null class. \
- `sampled_distributions`: an array of sampled distributions stacked column-wise so that in the first row there is the probability for the target class ``y_1`` and in the second row the probability for the null class ``y_0``. \

Outputs: \
- `mean_class_one`: a scalar that measures the average prediction for the target class. \
- `mean_class_zero`: a scalar that measures the average prediction for the null class.
"""
function sharpness_classification(y_binary, sampled_distributions)
    mean_class_one = mean(sampled_distributions[1, findall(y_binary .== 1)])
    mean_class_zero = mean(sampled_distributions[2, findall(y_binary .== 0)])
    return mean_class_one, mean_class_zero
end
@@ -0,0 +1,25 @@
using Statistics
using LaplaceRedux
@testset "sharpness_classification tests" begin
    y_binary = [0, 1, 0, 1, 1, 0, 1, 0]
    sampled_distributions = [
        0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8
        0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2
    ]
    mean_class_one, mean_class_zero = sharpness_classification(
        y_binary, sampled_distributions
    )
    @test mean_class_one ≈ mean(sampled_distributions[1, [2, 4, 5, 7]])
    @test mean_class_zero ≈ mean(sampled_distributions[2, [1, 3, 6, 8]])
end
# Test for `sharpness_regression` function
@testset "sharpness_regression tests" begin
    sampled_distributions = [
        [0.1, 0.2, 0.3, 0.7, 0.6],
        [0.2, 0.3, 0.4, 0.3, 0.5],
        [0.3, 0.4, 0.5, 0.9, 0.2],
    ]
    mean_variance = mean(map(var, sampled_distributions))
    sharpness = sharpness_regression(sampled_distributions)

    @test sharpness ≈ mean_variance
end
Review thread on the `predict` return values:
- "I guess we could just keep this consistent and return everything in both cases?"
- "Edit: my bad; as discussed, let's just add an option for classification to return the distribution. By default we should still return probabilities for now, but at least we give the option and add that to the docstring."