Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

90 average calibration functions in utils.jl #97

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
733312a
function empirical_frequency
pasq-cat Jun 9, 2024
25ea642
fixed the docstring.
pasq-cat Jun 9, 2024
c290ed8
added sharpness and binary classification. i have yet to test them pr…
pasq-cat Jun 14, 2024
4ff22f4
added trapz to the list of dependencies.
pasq-cat Jun 15, 2024
6a22210
added Distributions to theproject
pasq-cat Jun 15, 2024
df3d60d
working version
pasq-cat Jun 15, 2024
09f25e8
ops forgot to add sharpness for the classification case
pasq-cat Jun 15, 2024
07b318f
working release.. changed changelog, glm_predictive_distribution, pr…
pasq-cat Jun 21, 2024
eafa7bd
function empirical_frequency
pasq-cat Jun 9, 2024
f66e08e
fixed the docstring.
pasq-cat Jun 9, 2024
5355281
added sharpness and binary classification. i have yet to test them pr…
pasq-cat Jun 14, 2024
2efaa99
added trapz to the list of dependencies.
pasq-cat Jun 15, 2024
26643ee
added Distributions to theproject
pasq-cat Jun 15, 2024
b79ca39
working version
pasq-cat Jun 15, 2024
0d71736
ops forgot to add sharpness for the classification case
pasq-cat Jun 15, 2024
5f772cf
working release.. changed changelog, glm_predictive_distribution, pr…
pasq-cat Jun 21, 2024
d146d1d
Merge branch '90-average-calibration-in-utilsjl' of https://github.co…
pasq-cat Jun 21, 2024
7af9378
changed docstrings in predicting.jl
pasq-cat Jun 21, 2024
2c42236
fixed glm_predictive_distribution
pasq-cat Jun 22, 2024
9d67ddc
Update src/utils.jl
pasq-cat Jun 22, 2024
9f07583
Update src/utils.jl
pasq-cat Jun 22, 2024
f81d226
Update src/utils.jl
pasq-cat Jun 22, 2024
6cdc503
Update src/baselaplace/predicting.jl
pasq-cat Jun 22, 2024
89bb19b
Update src/baselaplace/predicting.jl
pasq-cat Jun 22, 2024
6fe01a2
JuliaFormatter
pasq-cat Jun 22, 2024
0bba488
fixed docstrings
pasq-cat Jun 23, 2024
8311de3
made docstrings a lil bit shorter
pasq-cat Jun 23, 2024
7837333
docstrings again (added output)
pasq-cat Jun 24, 2024
b0518b2
fixed binary classification case, exported function from utils.
pasq-cat Jun 24, 2024
6a9ee1b
juliaformatter
pasq-cat Jun 24, 2024
203513d
add n_bins as argument to functions
pasq-cat Jun 29, 2024
dce9bdb
ops forgot default value
pasq-cat Jun 29, 2024
b906c3b
ops forgot default value and removed a line
pasq-cat Jun 29, 2024
2059bed
Merge branch '90-average-calibration-in-utilsjl' of https://github.co…
pasq-cat Jun 29, 2024
3258618
juliaformatter----
pasq-cat Jun 29, 2024
c86dc25
fixed small error in pred_avg
pasq-cat Jun 30, 2024
3d2ebd6
fixed error in empirical_frequency_regression
pasq-cat Jun 30, 2024
4ab04f6
Update src/utils.jl
pasq-cat Jun 30, 2024
267b8f4
docstrings fixes and predict update
pasq-cat Jul 2, 2024
d188daf
fixed typos
pasq-cat Jul 2, 2024
270b70a
moved sharpness functions units tests in calibration.jl. changed run…
pasq-cat Jul 2, 2024
3320063
more sharpness unit tests
pasq-cat Jul 2, 2024
3750dbe
fixes and more unit tests
pasq-cat Jul 2, 2024
39d4bdc
small stuff
pasq-cat Jul 3, 2024
56c3b66
fix. there is still an issue with the shape of the input to use.
pasq-cat Jul 3, 2024
908c804
fixed logit.md ,moved functions to new file, removed changes to predi…
pasq-cat Jul 4, 2024
f468803
removed calibration_plots.md
pasq-cat Jul 4, 2024
459b2fe
test plot
pasq-cat Jul 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].

## Version [0.3.1] - 2024-06-22

### Changed

- Changed `glm_predictive_distribution` so that return a tuple(Normal distribution,fμ, fvar) rather than the tuple (mean,variance). [#90]

## Version [0.3.0] - 2024-06-21

### Changed

- Changed `glm_predictive_distribution` so that return a Normal distribution rather than the tuple (mean,variance). [#90]
- Changed `predict` so that return directly a Normal distribution in the case of regression. [#90]

### Added

- Added functions to compute the average empirical frequency for both classification and regression problems in utils.jl. [#90]





## Version [0.2.1] - 2024-05-29

### Changed
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.2.1"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
Expand All @@ -24,6 +25,7 @@ Aqua = "0.8"
ChainRulesCore = "1.23.0"
Compat = "4.7.0"
ComputationalResources = "0.3.2"
Distributions = "0.25.109"
Flux = "0.12, 0.13, 0.14"
LinearAlgebra = "1.6, 1.7, 1.8, 1.9, 1.10"
MLJFlux = "0.2.10, 0.3, 0.4"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
Expand All @@ -9,5 +10,6 @@ RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2 changes: 2 additions & 0 deletions src/LaplaceRedux.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module LaplaceRedux

include("utils.jl")
export empirical_frequency_binary_classification,
sharpness_classification, empirical_frequency_regression, sharpness_regression

include("data/Data.jl")
using .Data
Expand Down
18 changes: 13 additions & 5 deletions src/baselaplace/predicting.jl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could just keep this consistent and return everything in both cases?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: my bad, let's indeed as discussed just add an option for classification to return distribution. By default, we should still return probabilities for now, but at least we give the option and add that to the docstring.

Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Distributions: Distributions
using Statistics: mean, var
"""
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

Expand All @@ -22,6 +24,8 @@ Computes the linearized GLM predictive.
- `fμ::AbstractArray`: Mean of the predictive distribution. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: Variance of the predictive distribution. The output shape is column-major as in Flux.

- `normal_distr` An array of normal distributions approximating the predictive distribution p(y|X) given the input data X.

# Examples

```julia-repl
Expand All @@ -39,7 +43,9 @@ function glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
fμ = reshape(fμ, Flux.outputsize(la.model, size(X)))
fvar = functional_variance(la, 𝐉)
fvar = reshape(fvar, size(fμ)...)
return fμ, fvar
fstd = sqrt.(fvar)
normal_distr = [Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)]
return (normal_distr, fμ, fvar)
end

"""
Expand All @@ -55,9 +61,10 @@ Computes predictions from Bayesian neural network.
- `predict_proba::Bool=true`: If `true` (default), returns probabilities for classification tasks.

# Returns

For classification tasks:
- `fμ::AbstractArray`: Mean of the predictive distribution if link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: If regression, it also returns the variance of the predictive distribution. The output shape is column-major as in Flux.
For regression tasks:
- `normal_distr::Distributions.Normal`:the array of Normal distributions computed by glm_predictive_distribution. The output shape is column-major as in Flux.

# Examples

Expand All @@ -75,11 +82,12 @@ predict(la, hcat(x...))
function predict(
la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
)
fμ, fvar = glm_predictive_distribution(la, X)
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
#fμ, fvar = mean.(normal_distr), var.(normal_distr)

# Regression:
if la.likelihood == :regression
return fμ, fvar
return normal_distr
end

# Classification:
Expand Down
126 changes: 126 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Flux
using Statistics

"""
get_loss_fun(likelihood::Symbol)
Expand Down Expand Up @@ -39,3 +40,128 @@
function outdim(model::Chain)::Number
return [size(p) for p in Flux.params(model)][end][1]
end

"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice detail here! Just a formatting thing: can you use maths notation that's recognized by Documenter.jl please? See here for an example and here for the docs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i fixed the docstrings but for a new line ( without the empy line in the middle), i had to use the \ character.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i checked the results in the julia repl but i was not able to make the function appear in the documentation. i tried to go in docs/ and then type julia make.jl and i got "[ Info: SetupBuildDirectory: setting up build directory.
[ Info: Doctest: running doctests.
[ Info: ExpandTemplates: expanding markdown templates.
[ Info: CrossReferences: building cross-references.
[ Info: CheckDocument: running document checks.
[ Info: Populate: populating indices.
[ Info: RenderDocument: rendering document.
[ Info: HTMLWriter: rendering HTML pages.
┌ Warning: Unable to determine the repository root URL for the navbar link.
│ This can happen when a string is passed to the repo keyword of makedocs.

│ To remove this warning, either pass a Remotes.Remote object to repo to completely
│ specify the remote repository, or explicitly set the remote URL by setting repolink
│ via makedocs(format = HTML(repolink = "..."), ...).
└ @ Documenter.HTMLWriter C:\Users\Pasqu.julia\packages\Documenter\qoyeC\src\html\HTMLWriter.jl:732
[ Info: Automatic version="0.2.1" for inventory from ..\Project.toml
┌ Warning: Documenter could not auto-detect the building environment. Skipping deployment.
└ @ Documenter C:\Users\Pasqu.julia\packages\Documenter\qoyeC\src\deployconfig.jl:76 "
so i opened index.html with edge through vs code and my functions do not appear in the documentations when i search for them. do i have to modify some sort of flag?

empirical_frequency(Y_cal, sampled_distributions)

FOR REGRESSION MODELS.
Given a calibration dataset (x_t, y_t) for i ∈ {1,...,T} and an array of predicted distributions, the function calculates the empirical frequency
phat_j = {y_t|F_t(y_t)<= p_j, t= 1,....,T}/T, where T is the number of calibration points, p_j is the confidence level and F_t is the
cumulative distribution function of the predicted distribution targeting y_t.
Source: https://arxiv.org/abs/1807.00263

Inputs:
- 'Y_cal': a vector of values y_t
- 'sampled_distributions': an array of sampled distributions F(x_t) stacked column-wise.
- 'n_bins': number of equally spaced bins to use.
Outputs:
- counts: an array cointaining the empirical frequencies for each quantile interval.
"""
function empirical_frequency_regression(Y_cal, sampled_distributions, n_bins=20)
quantiles = collect(range(0; stop=1, length=n_bins + 1))
quantiles_matrix = hcat(

Check warning on line 62 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L60-L62

Added lines #L60 - L62 were not covered by tests
[quantile(samples, quantiles) for samples in sampled_distributions]...
)
n_rows = size(bounds_quantiles_matrix, 1)
counts = []

Check warning on line 66 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L65-L66

Added lines #L65 - L66 were not covered by tests

for i in 1:n_rows
push!(counts, sum(Y_cal .<= quantiles_matrix[i, :]) / length(Y_cal))
end
return counts

Check warning on line 71 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L68-L71

Added lines #L68 - L71 were not covered by tests
end

"""
sharpness(sampled_distributions)

FOR REGRESSION MODELS.
Given a calibration dataset (x_t, y_t) for i ∈ {1,...,T} and an array of predicted distributions, the function calculates the
sharpness of the predicted distributions, i.e., the average of the variances var(F_t) predicted by the forecaster for each x_t
Source: https://arxiv.org/abs/1807.00263

Inputs:
- sampled_distributions: an array of sampled distributions F(x_t) stacked column-wise.
Outputs:
- sharpness: a scalar that measure the level of sharpness of the regressor
"""
function sharpness_regression(sampled_distributions)
sharpness = mean(var.(sampled_distributions))
return sharpness

Check warning on line 89 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L87-L89

Added lines #L87 - L89 were not covered by tests
end

"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above comment on maths notation

empirical_frequency-classification(y_binary, sampled_distributions)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dash should be an underscore


FOR BINARY CLASSIFICATION MODELS.
Given a calibration dataset (x_t, y_t) for i ∈ {1,...,T} let p_t= H(x_t)∈[0,1] be the forecasted probability.
We group the p_t into intervals I-j for j= 1,2,...,m that form a partition of [0,1]. The function computes
the observed average p_j= T^-1_j ∑_{t:p_t ∈ I_j} y_j in each interval I_j.
Source: https://arxiv.org/abs/1807.00263

Inputs:
- y_binary: the array of outputs y_t numerically coded: 1 for the target class, 0 for the null class.
- sampled_distributions: an array of sampled distributions stacked column-wise so that in the first row
there is the probability for the target class y_1 and in the second row the probability for the null class y_0.
- 'n_bins': number of equally spaced bins to use.
Outputs:
- num_p_per_interval: array with the number of probabilities falling within interval
- emp_avg: array with the observed empirical average per interval
- bin_centers: array with the centers of the bins

"""
function empirical_frequency_binary_classification(

Check warning on line 112 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L112

Added line #L112 was not covered by tests
y_binary, sampled_distributions, n_bins=20
)
#intervals boundaries
int_bds = collect(range(0; stop=1, length=n_bins + 1))

Check warning on line 116 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L116

Added line #L116 was not covered by tests
#bin centers
bin_centers = [(int_bds[i] + int_bds[i + 1]) / 2 for i in 1:(length(int_bds) - 1)]

Check warning on line 118 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L118

Added line #L118 was not covered by tests
#initialize list for empirical averages per interval
emp_avg = []

Check warning on line 120 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L120

Added line #L120 was not covered by tests
#initialize list for predicted averages per interval
pred_avg = []

Check warning on line 122 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L122

Added line #L122 was not covered by tests
# initialize list of number of probabilities falling within each intervals
num_p_per_interval = []

Check warning on line 124 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L124

Added line #L124 was not covered by tests
#list of the predicted probabilities for the target class
class_probs = sampled_distributions[1, :]

Check warning on line 126 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L126

Added line #L126 was not covered by tests
# iterate over the bins
for j in 1:n_bins
push!(num_p_per_interval, sum(int_bds[j] .< class_probs .< int_bds[j + 1]))
if num_p_per_interval[j] == 0
push!(emp_avg, 0)
push!(pred_avg, bin_centers[j])

Check warning on line 132 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L128-L132

Added lines #L128 - L132 were not covered by tests

else
# find the indices fo all istances for which class_probs fall withing the j-th interval
indices = findall(x -> int_bds[j] < x < int_bds[j + 1], class_probs)

Check warning on line 136 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L136

Added line #L136 was not covered by tests
#compute the empirical average and saved it in emp_avg in the j-th position
push!(emp_avg, 1 / num_p_per_interval[j] * sum(y_binary[indices]))

Check warning on line 138 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L138

Added line #L138 was not covered by tests
#TO DO: maybe substitute to bin_Centers?
push!(pred_avg, 1 / num_p_per_interval[j] * sum(class_probs[ indices]))

Check warning on line 140 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L140

Added line #L140 was not covered by tests
pasq-cat marked this conversation as resolved.
Show resolved Hide resolved
end
end

Check warning on line 142 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L142

Added line #L142 was not covered by tests
#return the tuple
return (num_p_per_interval, emp_avg, bin_centers)

Check warning on line 144 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L144

Added line #L144 was not covered by tests
end

"""
sharpness-classification(y_binary,sampled_distributions)

FOR BINARY CLASSIFICATION MODELS.
Assess the sharpness of the model by looking at the distribution of model predictions. When forecasts are sharp,
most predictions are close to 0 or 1; not sharp forecasters make predictions closer to 0.5.
Source: https://arxiv.org/abs/1807.00263

Inputs:
-y_binary: the array of outputs y_t numerically coded . 1 for the target class, 0 for the negative result.
-sampled_distributions: an array of sampled distributions stacked column-wise so that in the first row
there is the probability for the target class and in the second row the probability for the null class.
Outputs:
- mean_class_one: a scalar that measure the average prediction for the target class
- mean_class_zero: a scalar that measure the average prediction for the null class
"""
function sharpness_classification(y_binary, sampled_distributions)
mean_class_one = mean(sampled_distributions[1, findall(y_binary .== 1)])
mean_class_zero = mean(sampled_distributions[2, findall(y_binary .== 0)])
return mean_class_one, mean_class_zero

Check warning on line 166 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L163-L166

Added lines #L163 - L166 were not covered by tests
end
Loading