90 average calibration functions in utils.jl #97
Conversation
…operly due to the issue with mljflux.
…m/JuliaTrustworthyAI/LaplaceRedux.jl into 90-average-calibration-in-utilsjl
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/utils.jl
[JuliaFormatter] reported by reviewdog 🐶
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/baselaplace/predicting.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
normal_distr = [
    Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)]
return (normal_distr,fμ,fvar )
Suggested change:
normal_distr = [Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)]
return (normal_distr, fμ, fvar)
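For readers following along, the construction the formatter is reshaping here runs standalone. This is a minimal sketch with toy `fμ`/`fstd` values (in the package these come from `glm_predictive_distribution`):

```julia
using Distributions, Statistics

# Toy means and standard deviations standing in for the outputs of
# glm_predictive_distribution (assumption: one Normal per output row).
fμ = [0.0, 1.0]
fstd = [1.0, 2.0]
fvar = fstd .^ 2

# One Normal predictive distribution per output, as in the suggested code.
normal_distr = [Distributions.Normal(fμ[i], fstd[i]) for i in 1:size(fμ, 1)]

(normal_distr, fμ, fvar)
```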
src/baselaplace/predicting.jl
@@ -75,11 +83,12 @@ predict(la, hcat(x...))
function predict(
    la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true
)
-    fμ, fvar = glm_predictive_distribution(la, X)
+    normal_distr,fμ, fvar = glm_predictive_distribution(la, X)
[JuliaFormatter] reported by reviewdog 🐶
Suggested change:
normal_distr, fμ, fvar = glm_predictive_distribution(la, X)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
yeah basically juliaformatter being a pain for nothing Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Really nice work here @Rockdeldiablo 👍🏽 just a few missing pieces:
- check individual comments (minor things)
- add a tutorial to the documentation
- add unit tests (at least passing through functions and having simple tests for expected output type, then new issue for testing with synthetic data).
I guess we could just keep this consistent and return everything in both cases?
Edit: my bad. As discussed, let's indeed just add an option for classification to return the distribution. By default we should still return probabilities for now, but at least we give the option, and we add that to the docstring.
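A minimal sketch of what that opt-in could look like, assuming a hypothetical `ret_distr` keyword and helper name (the actual LaplaceRedux API is still to be decided):

```julia
using Distributions

# Hypothetical sketch, not the actual LaplaceRedux API: classification
# returns probabilities by default, with an opt-in distribution output.
function classification_output(probs::AbstractVector{<:Real}; ret_distr::Bool=false)
    return ret_distr ? Distributions.Categorical(probs) : probs
end

classification_output([0.2, 0.8])                  # default: the probability vector
classification_output([0.2, 0.8]; ret_distr=true)  # opt-in: a Categorical distribution
```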
src/utils.jl
@@ -39,3 +40,128 @@ corresponding to the number of neurons on the last layer of the NN.
function outdim(model::Chain)::Number
    return [size(p) for p in Flux.params(model)][end][1]
end

"""
I fixed the docstrings, but for a new line (without the empty line in the middle) I had to use the \ character.
I checked the results in the Julia REPL, but I was not able to make the function appear in the documentation. I went into docs/ and ran `julia make.jl`, and got:
[ Info: SetupBuildDirectory: setting up build directory.
[ Info: Doctest: running doctests.
[ Info: ExpandTemplates: expanding markdown templates.
[ Info: CrossReferences: building cross-references.
[ Info: CheckDocument: running document checks.
[ Info: Populate: populating indices.
[ Info: RenderDocument: rendering document.
[ Info: HTMLWriter: rendering HTML pages.
┌ Warning: Unable to determine the repository root URL for the navbar link.
│ This can happen when a string is passed to the `repo` keyword of `makedocs`.
│
│ To remove this warning, either pass a Remotes.Remote object to `repo` to completely
│ specify the remote repository, or explicitly set the remote URL by setting `repolink`
│ via `makedocs(format = HTML(repolink = "..."), ...)`.
└ @ Documenter.HTMLWriter C:\Users\Pasqu\.julia\packages\Documenter\qoyeC\src\html\HTMLWriter.jl:732
[ Info: Automatic `version="0.2.1"` for inventory from ..\Project.toml
┌ Warning: Documenter could not auto-detect the building environment. Skipping deployment.
└ @ Documenter C:\Users\Pasqu\.julia\packages\Documenter\qoyeC\src\deployconfig.jl:76
So I opened index.html with Edge through VS Code, and my functions do not appear in the documentation when I search for them. Do I have to modify some sort of flag?
src/utils.jl
    return sharpness
end

"""
Same as above comment on maths notation
Regarding your question on docs, just make sure that the function signature stated in the docstring matches the function
src/utils.jl
"""
function sharpness_regression(sampled_distributions)
    sharpness = mean(var.(sampled_distributions))
    return sharpness
end

"""
@doc raw"""
    empirical_frequency-classification(y_binary, sampled_distributions)
The dash should be an underscore
src/utils.jl
@@ -144,21 +148,22 @@ function empirical_frequency_binary_classification(
    return (num_p_per_interval, emp_avg, bin_centers)
end

"""
    sharpness-classification(y_binary,sampled_distributions)
Same as above, dash should be an underscore
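For context, the two sharpness helpers being renamed here can be sketched with the standard library only. The implementations below are inferred from the tests later in this thread, so treat them as illustrative rather than the package's exact code:

```julia
using Statistics

# Average variance across the sampled predictive distributions.
function sharpness_regression(sampled_distributions)
    return mean(var.(sampled_distributions))
end

# Mean predicted probability, separately for targets 1 and 0; by the
# convention in the tests, row 1 holds p(class 1) and row 2 holds p(class 0).
function sharpness_classification(y_binary, sampled_distributions)
    mean_class_one = mean(sampled_distributions[1, y_binary .== 1])
    mean_class_zero = mean(sampled_distributions[2, y_binary .== 0])
    return mean_class_one, mean_class_zero
end

sharpness_regression([[1.0, 2.0, 3.0]])  # 1.0
sharpness_classification([1, 0, 1], [0.9 0.1 0.8; 0.1 0.9 0.2])
```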
…ests.jl to add the new file.
src/utils.jl
- `bin_centers`: array with the centers of the bins.

"""
[JuliaFormatter] reported by reviewdog 🐶
Original:
function empirical_frequency_binary_classification(y_binary, sampled_distributions, n_bins::Int=20)
Suggested change:
function empirical_frequency_binary_classification(
    y_binary, sampled_distributions, n_bins::Int=20
)
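As a rough companion to the signature above, the binning logic behind such a calibration check can be sketched as follows (standard library only; the bin conventions, empty-bin handling, and the helper name are assumptions, not the package's exact behaviour):

```julia
using Statistics

# Bin the predicted probabilities for class 1, then compare each bin's
# empirical frequency of y == 1 against the bin center. Half-open bins;
# probability exactly 1.0 falls outside the last bin in this sketch.
function empirical_frequency_sketch(y_binary, probs; n_bins::Int=20)
    n_bins > 0 || throw(ArgumentError("n_bins must be positive"))
    edges = range(0, 1; length=n_bins + 1)
    centers = [(edges[i] + edges[i + 1]) / 2 for i in 1:n_bins]
    counts = zeros(Int, n_bins)
    emp = fill(NaN, n_bins)  # NaN marks bins with no predictions
    for i in 1:n_bins
        in_bin = (probs .>= edges[i]) .& (probs .< edges[i + 1])
        counts[i] = count(in_bin)
        counts[i] > 0 && (emp[i] = mean(y_binary[in_bin]))
    end
    return counts, emp, centers
end

empirical_frequency_sketch([1, 0, 1, 0], [0.9, 0.1, 0.8, 0.2]; n_bins=2)
```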
test/calibration.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8;
0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2
Suggested change:
0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8
0.9 0.8 0.7 0.6 0.5 0.4 0.3 0.2
test/calibration.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
mean_class_one, mean_class_zero = sharpness_classification(y_binary, sampled_distributions)
@test mean_class_one ≈ mean(sampled_distributions[1,[2,4,5,7]])
@test mean_class_zero ≈ mean(sampled_distributions[2,[1,3,6,8]])
Suggested change:
mean_class_one, mean_class_zero = sharpness_classification(
    y_binary, sampled_distributions
)
@test mean_class_one ≈ mean(sampled_distributions[1, [2, 4, 5, 7]])
@test mean_class_zero ≈ mean(sampled_distributions[2, [1, 3, 6, 8]])
test/calibration.jl
# Test for `sharpness_regression` function
@testset "sharpness_regression tests" begin
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = [[0.1, 0.2, 0.3, 0.7, 0.6], [0.2, 0.3, 0.4, 0.3 , 0.5 ], [0.3, 0.4, 0.5, 0.9, 0.2]]
Suggested change:
sampled_distributions = [
    [0.1, 0.2, 0.3, 0.7, 0.6], [0.2, 0.3, 0.4, 0.3, 0.5], [0.3, 0.4, 0.5, 0.9, 0.2]
]
test/calibration.jl
sharpness = sharpness_regression(sampled_distributions)

@test sharpness ≈ mean_variance
end
[JuliaFormatter] reported by reviewdog 🐶 (whitespace-only change around `end`)
test/calibration.jl
@testset "sharpness_classification tests" begin
[JuliaFormatter] reported by reviewdog 🐶
test/calibration.jl
# Test 1: Check that the function runs without errors and returns two scalars for a simple case
y_binary = [1, 0, 1, 0, 1]
sampled_distributions = [0.9 0.1 0.8 0.2 0.7; 0.1 0.9 0.2 0.8 0.3] # Sampled probabilities
[JuliaFormatter] reported by reviewdog 🐶
Original:
mean_class_one, mean_class_zero = sharpness_classification(y_binary, sampled_distributions)
Suggested change:
mean_class_one, mean_class_zero = sharpness_classification(
    y_binary, sampled_distributions
)
test/calibration.jl
@test typeof(mean_class_one) <: Real # Check if mean_class_one is a scalar
@test typeof(mean_class_zero) <: Real # Check if mean_class_zero is a scalar
[JuliaFormatter] reported by reviewdog 🐶
test/calibration.jl
# Test 4: Edge case with all zeros in y_binary
y_binary_all_zeros = [0, 0, 0]
sampled_distributions_all_zeros = [0.1 0.2 0.3; 0.9 0.8 0.7]
[JuliaFormatter] reported by reviewdog 🐶
Original:
mean_class_one_all_zeros, mean_class_zero_all_zeros = sharpness_classification(y_binary_all_zeros, sampled_distributions_all_zeros)
Suggested change:
mean_class_one_all_zeros, mean_class_zero_all_zeros = sharpness_classification(
    y_binary_all_zeros, sampled_distributions_all_zeros
)
test/calibration.jl
@test typeof(sharpness) <: Real # Check if the output is a scalar

# Test 2: Check the function with a known input
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = [[0.1, 0.2, 0.3, 0.7, 0.6], [0.2, 0.3, 0.4, 0.3 , 0.5 ], [0.3, 0.4, 0.5, 0.9, 0.2]]
Suggested change:
sampled_distributions = [
    [0.1, 0.2, 0.3, 0.7, 0.6], [0.2, 0.3, 0.4, 0.3, 0.5], [0.3, 0.4, 0.5, 0.9, 0.2]
]
test/calibration.jl
end
[JuliaFormatter] reported by reviewdog 🐶 (whitespace-only change around `end`)
Is it possible to stop these checks until I am ready? GitHub automatically runs the checks on every commit.....
test/calibration.jl
Outdated
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
test/calibration.jl
y_binary = rand(0:1, 10)
sampled_distributions = rand(Normal(0.5, 0.1), 10, 6)
n_bins = 4
[JuliaFormatter] reported by reviewdog 🐶
Original:
num_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(y_binary, sampled_distributions, n_bins=n_bins)
Suggested change:
num_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(
    y_binary, sampled_distributions; n_bins=n_bins
)
test/calibration.jl
Y_cal = [0, 1, 0, 1.2, 4]
sampled_distributions = rand(Normal(0.5, 0.1), 5, 6)
[JuliaFormatter] reported by reviewdog 🐶
Original:
@test_throws ArgumentError empirical_frequency_binary_classification(Y_cal, sampled_distributions, n_bins=10)
Suggested change:
@test_throws ArgumentError empirical_frequency_binary_classification(
    Y_cal, sampled_distributions, n_bins=10
)
test/calibration.jl
sampled_distributions = rand(Normal(0.5, 0.1), 5, 6)
[JuliaFormatter] reported by reviewdog 🐶
Original:
@test_throws ArgumentError empirical_frequency_binary_classification(Y_cal, sampled_distributions, n_bins=0)
Suggested change:
@test_throws ArgumentError empirical_frequency_binary_classification(
    Y_cal, sampled_distributions, n_bins=0
)
end
src/baselaplace/predicting.jl
- `fμ::AbstractArray`: Mean of the predictive distribution if the link function is set to `:plugin`, otherwise the probit approximation. The output shape is column-major as in Flux.
- `fvar::AbstractArray`: For regression, the variance of the predictive distribution. The output shape is column-major as in Flux.
For classification tasks, LaplaceRedux provides different options:
- `normal_distr::Distributions.Normal`: the array of Normal distributions computed by `glm_predictive_distribution`, if `link_approx` is set to `:distribution`.
@pat-alt I am confused about how to proceed here. The issue is that the output of the chain may have already passed through a softmax layer, in which case the output should not be converted again. Should we add a check for this, or leave it to the educated reader lol?
Good point! I think we should probably add a check somewhere in the corestruct that contains the Flux chain. Let's open a separate issue for this.
Either that, or modify the predict function so that it checks whether la.model has a finaliser layer.
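The check could be sketched like this. To keep the example dependency-free, `mysoftmax` and a plain tuple of callables stand in for `Flux.softmax` and `Chain.layers`; the idea is just to compare the last layer against known finalisers:

```julia
# Toy stand-ins (assumptions): `mysoftmax` plays the role of Flux.softmax,
# and a plain tuple of callables plays the role of Chain.layers.
mysoftmax(x) = exp.(x) ./ sum(exp.(x))
dense_stub(x) = 2 .* x

# The finaliser check discussed above: does the chain already end in softmax?
has_finaliser(layers::Tuple) = !isempty(layers) && layers[end] === mysoftmax

has_finaliser((dense_stub, mysoftmax))  # true: output is already normalised
has_finaliser((dense_stub,))            # false: predict should apply the link
```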
src/utils.jl
@doc raw"""
    empirical_frequency_regression(Y_cal, sampled_distributions, n_bins=20)

FOR REGRESSION MODELS. \
@pat-alt maybe I can move these functions to a dedicated Julia file (calibration_functions.jl), so that in the future I can add something else in a compartmentalized file
yes, good idea
test/calibration.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = [rand(Distributions.Normal(1, 1.0),6) for _ in 1:5]
@test_throws ArgumentError empirical_frequency_regression(Y_cal, sampled_distributions, n_bins=0)
Suggested change:
sampled_distributions = [rand(Distributions.Normal(1, 1.0), 6) for _ in 1:5]
@test_throws ArgumentError empirical_frequency_regression(
    Y_cal, sampled_distributions, n_bins=0
)
test/calibration.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = [rand(Distributions.Normal(1, 1.0),6) for _ in 1:5]
counts = empirical_frequency_regression(Y_cal, sampled_distributions, n_bins=n_bins)
@test typeof(counts) == Array{Float64, 1} # Check if the output is an array of Float64
@test length(counts) == n_bins + 1
Suggested change:
sampled_distributions = [rand(Distributions.Normal(1, 1.0), 6) for _ in 1:5]
counts = empirical_frequency_regression(Y_cal, sampled_distributions; n_bins=n_bins)
@test typeof(counts) == Array{Float64,1} # Check if the output is an array of Float64
@test length(counts) == n_bins + 1
test/calibration.jl
@testset "empirical_frequency_binary_classification tests" begin
    # Test 1: Check that the function runs without errors and returns an array for a simple case
    y_binary = rand(0:1, 10)
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = rand(2,10)
Suggested change:
sampled_distributions = rand(2, 10)
test/calibration.jl
y_binary = rand(0:1, 10)
sampled_distributions = rand(2,10)
n_bins = 4
[JuliaFormatter] reported by reviewdog 🐶
Original:
num_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(y_binary, sampled_distributions, n_bins=n_bins)
Suggested change:
num_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(
    y_binary, sampled_distributions; n_bins=n_bins
)
test/calibration.jl
Y_cal = [0, 1, 0, 1.2, 4]
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = rand(2,5)
@test_throws ArgumentError empirical_frequency_binary_classification(Y_cal, sampled_distributions, n_bins=10)
Suggested change:
sampled_distributions = rand(2, 5)
@test_throws ArgumentError empirical_frequency_binary_classification(
    Y_cal, sampled_distributions, n_bins=10
)
test/calibration.jl
[JuliaFormatter] reported by reviewdog 🐶
Original:
sampled_distributions = rand(2,5)
@test_throws ArgumentError empirical_frequency_binary_classification(Y_cal, sampled_distributions, n_bins=0)
end
Suggested change:
sampled_distributions = rand(2, 5)
@test_throws ArgumentError empirical_frequency_binary_classification(
    Y_cal, sampled_distributions, n_bins=0
)
end
I know, I just forgot to do it. I wish it were applied automatically before synchronization.
…ct (will address it in a different issue); fixed docstring for calibration.jl
Added functions in utils.jl for computing the average empirical frequency, and changed glm_predictive_distribution to return a normal distribution in the case of regression. It should pass all tests.