From f007cf4e64e9cdc3d0b74334d924df62dba01521 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Sun, 24 Mar 2024 18:38:46 +0100 Subject: [PATCH 01/12] don't understand these errors rn --- Project.toml | 7 +++++-- test/Project.toml | 2 ++ test/aqua.jl | 8 ++++++++ test/runtests.jl | 1 + 4 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 test/aqua.jl diff --git a/Project.toml b/Project.toml index e9b0d4c..d35fefa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaijaPlotting" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" authors = ["Patrick Altmeyer"] -version = "1.0.8" +version = "1.1.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -20,6 +20,7 @@ NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [compat] +Aqua = "0.8" CategoricalArrays = "0.10" ConformalPrediction = "0.1" CounterfactualExplanations = "0.1" @@ -34,10 +35,12 @@ MultivariateStats = "0.9, 0.10" NaturalSort = "1" NearestNeighborModels = "0.2" Plots = "1" +Test = "1.7, 1.8, 1.9, 1.10" julia = "1.7, 1.8, 1.9, 1.10" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Aqua", "Test"] diff --git a/test/Project.toml b/test/Project.toml index e86e793..faab27e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -11,6 +12,7 @@ TaijaData = "9d524318-b4e6-4a65-86d2-b2b72d07866c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Aqua = "0.8" ConformalPrediction = "0.1" CounterfactualExplanations = "0.1" Flux = "0.12, 0.13" diff --git a/test/aqua.jl b/test/aqua.jl new file mode 100644 index 0000000..8adc0d1 --- /dev/null +++ b/test/aqua.jl @@ -0,0 +1,8 @@ +using 
Aqua + +@testset "Aqua.jl" begin + # Ambiguities needs to be tested seperately until the bug in Aqua package (https://github.com/JuliaTesting/Aqua.jl/issues/77) is fixed + Aqua.test_ambiguities([TaijaPlotting]; recursive=false, broken=false) + + Aqua.test_all(TaijaPlotting; ambiguities=false) +end diff --git a/test/runtests.jl b/test/runtests.jl index 98b0638..e5be431 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using TaijaPlotting using Test @testset "TaijaPlotting.jl" begin + include("aqua.jl") include("ConformalPrediction.jl") include("CounterfactualExplanations.jl") include("LaplaceRedux.jl") From 3e19c926295d49fbd448e30d761c02e4ea10700b Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Sep 2024 15:11:51 +0200 Subject: [PATCH 02/12] slowly slowly --- Project.toml | 4 +- src/CounterfactualExplations/data.jl | 10 +- src/CounterfactualExplations/models.jl | 174 +++++++++++++++++++------ src/LaplaceRedux/LaplaceRedux.jl | 4 +- src/TaijaPlotting.jl | 5 +- 5 files changed, 152 insertions(+), 45 deletions(-) diff --git a/Project.toml b/Project.toml index 07d7d2a..b01155b 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" @@ -17,7 +16,9 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RecipesBase = 
"3cdcf5f2-1ef4-517c-9805-6587b60abb01" Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" [compat] @@ -26,7 +27,6 @@ ConformalPrediction = "0.1, 1" CounterfactualExplanations = "1.1.5" DataAPI = "1" Distributions = "0.25" -Flux = "0.12, 0.13, 0.14" LaplaceRedux = "0.1, 0.2, 1.0.1" LinearAlgebra = "1.10" MLJBase = "0.21, 0.22, 1" diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl index fabc294..bae9ddc 100644 --- a/src/CounterfactualExplations/data.jl +++ b/src/CounterfactualExplations/data.jl @@ -53,8 +53,14 @@ function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol = :pca) return X', y, multi_dim end -function Plots.scatter!(data::CounterfactualData; dim_red::Symbol = :pca, kwargs...) +@recipe function f(data::CounterfactualData; dim_red=:pca) + + # Set up: X, y, _ = prepare_for_plotting(data; dim_red = dim_red) _c = Int.(y.refs) - return Plots.scatter!(X[:, 1], X[:, 2]; group = y, colour = _c, kwargs...) + group := y + markercolor := _c + + # return data + return X[:, 1], X[:, 2] end diff --git a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index 19fbf16..f170b77 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -2,28 +2,27 @@ using DataAPI using Distributions: pdf using NearestNeighborModels: KNNClassifier -function Plots.plot( + + +@recipe function f( M::AbstractFittedModel, - data::DataPreprocessing.CounterfactualData; - target::Union{Nothing,RawTargetType} = nothing, - colorbar = true, - title = "", - length_out = 100, - zoom = -0.1, - xlims = nothing, - ylims = nothing, - linewidth = 0.1, - alpha = 1.0, - contour_alpha = 1.0, - dim_red::Symbol = :pca, - plot_loss::Bool = false, - loss_fun::Union{Nothing,Function} = nothing, - kwargs..., + data::CounterfactualData; + target=nothing, + length_out=100, + zoom=-0.1, + dim_red=:pca, + plot_loss=false, + loss_fun=nothing, ) + + # Get user-defined arguments: + xlims = 
get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + X, _ = DataPreprocessing.unpack_data(data) ŷ = probs(M, X) # true predictions if size(ŷ, 1) > 1 - ŷ = vec(Flux.onecold(ŷ, 1:size(ŷ, 1))) + ŷ = vec(OneHotArrays.onecold(ŷ, 1:size(ŷ, 1))) else ŷ = vec(ŷ) end @@ -35,7 +34,7 @@ function Plots.plot( end target_encoded = data.output_encoder(target) - X, y, multi_dim = prepare_for_plotting(data; dim_red = dim_red) + X, y, multi_dim = prepare_for_plotting(data; dim_red=dim_red) # Surface range: zoom = zoom * maximum(abs.(X)) @@ -44,17 +43,18 @@ function Plots.plot( else xlims = xlims .+ (zoom, -zoom) end + if isnothing(ylims) ylims = (minimum(X[:, 2]), maximum(X[:, 2])) .+ (zoom, -zoom) else ylims = ylims .+ (zoom, -zoom) end - x_range = convert.(eltype(X), range(xlims[1]; stop = xlims[2], length = length_out)) - y_range = convert.(eltype(X), range(ylims[1]; stop = ylims[2], length = length_out)) + x_range = convert.(eltype(X), range(xlims[1]; stop=xlims[2], length=length_out)) + y_range = convert.(eltype(X), range(ylims[1]; stop=ylims[2], length=length_out)) plot_loss = plot_loss || !isnothing(loss_fun) - if plot_loss + if plot_loss # Loss surface: Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range] else @@ -84,24 +84,124 @@ function Plots.plot( target_idx = get_target_index(data.y_levels, target) z = plot_loss ? Z[1, :] : Z[target_idx, :] - # Contour: - Plots.contourf( - x_range, - y_range, - z; - colorbar = colorbar, - title = title, - linewidth = linewidth, - xlims = xlims, - ylims = ylims, - kwargs..., - alpha = contour_alpha, - ) - - # Samples: - return Plots.scatter!(data; dim_red = dim_red, alpha = alpha, kwargs...) 
+ return x_range, y_range, z + end +@userplot struct ModelPlot{T<:Tuple{AbstractModel,CounterfactualData}} + args::T +end + +@recipe function f(mp::ModelPlot) + model = mp.args[1] + data = mp.args[2] + plt = contourf(model, data) + scatter!(data) + display(plt) + return nothing +end + + +# function Plots.plot( +# M::AbstractFittedModel, +# data::DataPreprocessing.CounterfactualData; +# target::Union{Nothing,RawTargetType} = nothing, +# colorbar = true, +# title = "", +# length_out = 100, +# zoom = -0.1, +# xlims = nothing, +# ylims = nothing, +# linewidth = 0.1, +# alpha = 1.0, +# contour_alpha = 1.0, +# dim_red::Symbol = :pca, +# plot_loss::Bool = false, +# loss_fun::Union{Nothing,Function} = nothing, +# kwargs..., +# ) +# X, _ = DataPreprocessing.unpack_data(data) +# ŷ = probs(M, X) # true predictions +# if size(ŷ, 1) > 1 +# ŷ = vec(Flux.onecold(ŷ, 1:size(ŷ, 1))) +# else +# ŷ = vec(ŷ) +# end + +# # Target: +# if isnothing(target) +# target = data.y_levels[1] +# @info "No target label supplied, using first." 
+# end +# target_encoded = data.output_encoder(target) + +# X, y, multi_dim = prepare_for_plotting(data; dim_red = dim_red) + +# # Surface range: +# zoom = zoom * maximum(abs.(X)) +# if isnothing(xlims) +# xlims = (minimum(X[:, 1]), maximum(X[:, 1])) .+ (zoom, -zoom) +# else +# xlims = xlims .+ (zoom, -zoom) +# end +# if isnothing(ylims) +# ylims = (minimum(X[:, 2]), maximum(X[:, 2])) .+ (zoom, -zoom) +# else +# ylims = ylims .+ (zoom, -zoom) +# end +# x_range = convert.(eltype(X), range(xlims[1]; stop = xlims[2], length = length_out)) +# y_range = convert.(eltype(X), range(ylims[1]; stop = ylims[2], length = length_out)) + +# plot_loss = plot_loss || !isnothing(loss_fun) + +# if plot_loss +# # Loss surface: +# Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range] +# else +# # Prediction surface: +# if multi_dim +# knn1, y_train = voronoi(X, ŷ) +# predict_ = +# (X::AbstractVector) -> vec( +# pdf( +# MLJBase.predict(knn1, MLJBase.table(reshape(X, 1, 2))), +# DataAPI.levels(y_train), +# ), +# ) +# Z = [predict_([x, y]) for x in x_range, y in y_range] +# else +# predict_ = function (X::AbstractVector) +# X = permutedims(permutedims(X)) +# z = predict_proba(M, data, X) +# return z +# end +# Z = [predict_([x, y]) for x in x_range, y in y_range] +# end +# end + +# # Pre-processes: +# Z = reduce(hcat, Z) +# target_idx = get_target_index(data.y_levels, target) +# z = plot_loss ? Z[1, :] : Z[target_idx, :] + +# # Contour: +# Plots.contourf( +# x_range, +# y_range, +# z; +# colorbar = colorbar, +# title = title, +# linewidth = linewidth, +# xlims = xlims, +# ylims = ylims, +# kwargs..., +# alpha = contour_alpha, +# ) + +# # Samples: +# return Plots.scatter!(data; dim_red = dim_red, alpha = alpha, kwargs...) 
+# end + function voronoi(X::AbstractMatrix, y::AbstractVector) knnc = KNNClassifier(; K = 1) # KNNClassifier instantiation X = MLJBase.table(X) diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index b8fa737..b5395dc 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -125,7 +125,7 @@ function Plots.plot( end """ -'Calibration_Plot_Regression(y_cal, samp_distr, n_bins)' + calibration_plot(y_cal, samp_distr, n_bins) This plot displays the true frequency of points in each confidence interval relative to the predicted fraction of points in that interval. The intervals are taken in step of 0.05 quantiles. @@ -136,7 +136,7 @@ Input: -'samp_distr': an array of sampled distributions F(x_t) corresponding to the y_t stacked column-wise. -'n_bins': numbers of bins to use. """ -function Calibration_Plot(la::Laplace, y_cal, samp_distr; n_bins = 20) +function calibration_plot(la::Laplace, y_cal, samp_distr; n_bins = 20) quantiles = collect(range(0; stop = 1, length = n_bins + 1)) # Create a new plot object p = plot() diff --git a/src/TaijaPlotting.jl b/src/TaijaPlotting.jl index e37a519..722571a 100644 --- a/src/TaijaPlotting.jl +++ b/src/TaijaPlotting.jl @@ -2,16 +2,17 @@ module TaijaPlotting using CategoricalArrays using CounterfactualExplanations -using Flux using MLJBase using MultivariateStats +using OneHotArrays using Plots +using RecipesBase export animate_path include("ConformalPrediction/ConformalPrediction.jl") include("CounterfactualExplations/CounterfactualExplanations.jl") include("LaplaceRedux/LaplaceRedux.jl") -export Calibration_Plot +export calibration_plot end From 99751451776ade08e3fd0a0ab4101cb1e91cfb19 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Tue, 3 Sep 2024 19:17:28 +0200 Subject: [PATCH 03/12] some progress on recipes --- src/CounterfactualExplations/models.jl | 29 ++++++++++++-------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git 
a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index f170b77..054e476 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -2,8 +2,6 @@ using DataAPI using Distributions: pdf using NearestNeighborModels: KNNClassifier - - @recipe function f( M::AbstractFittedModel, data::CounterfactualData; @@ -84,21 +82,20 @@ using NearestNeighborModels: KNNClassifier target_idx = get_target_index(data.y_levels, target) z = plot_loss ? Z[1, :] : Z[target_idx, :] - return x_range, y_range, z - -end - -@userplot struct ModelPlot{T<:Tuple{AbstractModel,CounterfactualData}} - args::T -end + @series begin + seriestype := :contourf + colorbar := :none + x_range, y_range, z + end -@recipe function f(mp::ModelPlot) - model = mp.args[1] - data = mp.args[2] - plt = contourf(model, data) - scatter!(data) - display(plt) - return nothing + _c = Int.(y.refs) + @series begin + seriestype := :scatter + group := y + markercolor := _c + X[:,1], X[:,2] + end + end From 4097a0c498d1bee5e2f631df65ad5a986f066297 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 08:19:01 +0200 Subject: [PATCH 04/12] model plot sorted --- src/CounterfactualExplations/models.jl | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index 054e476..2bd40bc 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -13,10 +13,16 @@ using NearestNeighborModels: KNNClassifier loss_fun=nothing, ) + # Asserts + @assert !plot_loss || !isnothing(loss_fun) "Need to provide a loss function to plot the loss, e.g. (`loss_fun=Flux.Losses.logitcrossentropy`)." 
+ # Get user-defined arguments: xlims = get(plotattributes, :xlims, nothing) ylims = get(plotattributes, :ylims, nothing) + # Plot attributes + linewidth --> 0.1 + X, _ = DataPreprocessing.unpack_data(data) ŷ = probs(M, X) # true predictions if size(ŷ, 1) > 1 @@ -50,6 +56,9 @@ using NearestNeighborModels: KNNClassifier x_range = convert.(eltype(X), range(xlims[1]; stop=xlims[2], length=length_out)) y_range = convert.(eltype(X), range(ylims[1]; stop=ylims[2], length=length_out)) + xlims --> xlims + ylims --> ylims + plot_loss = plot_loss || !isnothing(loss_fun) if plot_loss @@ -82,18 +91,21 @@ using NearestNeighborModels: KNNClassifier target_idx = get_target_index(data.y_levels, target) z = plot_loss ? Z[1, :] : Z[target_idx, :] + # Contour plot: @series begin seriestype := :contourf - colorbar := :none x_range, y_range, z end - _c = Int.(y.refs) - @series begin - seriestype := :scatter - group := y - markercolor := _c - X[:,1], X[:,2] + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end end end From 3b030a24d6dafb49bfd67daae16720758d2f0f58 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 10:36:15 +0200 Subject: [PATCH 05/12] CE plot kind of working --- .../counterfactuals.jl | 171 +++++++++++---- src/CounterfactualExplations/data.jl | 23 ++- src/CounterfactualExplations/models.jl | 194 ++++++------------ src/LaplaceRedux/LaplaceRedux.jl | 2 +- 4 files changed, 206 insertions(+), 184 deletions(-) diff --git a/src/CounterfactualExplations/counterfactuals.jl b/src/CounterfactualExplations/counterfactuals.jl index f5a705f..806db33 100644 --- a/src/CounterfactualExplations/counterfactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,39 +1,66 @@ using MLUtils: stack -""" - Plots.plot( - ce::CounterfactualExplanation; - alpha_ = 0.5, - plot_up_to::Union{Nothing,Int} = nothing, - 
plot_proba::Bool = false, - kwargs..., - ) - -Calling `plot` on an instance of type `CounterfactualExplanation` returns a plot that visualises the entire counterfactual path. For multi-dimensional input data, the data is first compressed into two dimensions. The decision boundary is then approximated using using a Nearest Neighbour classifier. This is still somewhat experimental at the moment. - - -# Examples - -```julia-repl -# Search: -generator = GenericGenerator() -ce = generate_counterfactual(x, target, counterfactual_data, M, generator) - -plot(ce) -``` -""" -function Plots.plot( - ce_plot::CounterfactualExplanation; - alpha_ = 0.5, +@recipe function f( + ce::CounterfactualExplanation; + target=nothing, + length_out=100, + zoom=-0.1, + dim_red=:pca, + plot_loss=false, + loss_fun=nothing, plot_up_to::Union{Nothing,Int} = nothing, plot_proba::Bool = false, n_points = 1000, - kwargs..., ) - ce = deepcopy(ce_plot) + ce = deepcopy(ce) ce.data = DataPreprocessing.subsample(ce.data, n_points) + # Asserts + @assert !plot_loss || !isnothing(loss_fun) "Need to provide a loss function to plot the loss, e.g. (`loss_fun=Flux.Losses.logitcrossentropy`)." 
+ + # Get user-defined arguments: + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + + # Plot attributes + linewidth --> 0.1 + + contour_series, X, y, xlims, ylims = setup_model_plot( + ce.M, + ce.data, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, + ) + + xlims --> xlims + ylims --> ylims + + # Contour plot: + @series begin + seriestype := :contourf + contour_series[1], contour_series[2], contour_series[3] + end + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end + end + + alpha = get(plotattributes, :alpha, 0.5) + max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) total_steps(ce) @@ -41,21 +68,76 @@ function Plots.plot( minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) + ingredients = set_up_plots(ce; alpha = alpha, plot_proba = plot_proba) - for t = 1:max_iter - final_state = t == max_iter - plot_state(ce, t, final_state; ingredients...) + for X in eachslice(ingredients.path_embedded, dims=3) + for (x,y) in zip(eachcol(X),ingredients.path_labels) + @series begin + seriestype := :scatter + markercolor := CategoricalArrays.levelcode.(y[1]) + label := :none + x[1,:], X[2,:] + end + end end +end - plt = if plot_proba - Plots.plot(ingredients.p1, ingredients.p2; kwargs...) - else - Plots.plot(ingredients.p1; kwargs...) - end +# """ +# Plots.plot( +# ce::CounterfactualExplanation; +# alpha_ = 0.5, +# plot_up_to::Union{Nothing,Int} = nothing, +# plot_proba::Bool = false, +# kwargs..., +# ) - return plt -end +# Calling `plot` on an instance of type `CounterfactualExplanation` returns a plot that visualises the entire counterfactual path. For multi-dimensional input data, the data is first compressed into two dimensions. 
The decision boundary is then approximated using using a Nearest Neighbour classifier. This is still somewhat experimental at the moment. + + +# # Examples + +# ```julia-repl +# # Search: +# generator = GenericGenerator() +# ce = generate_counterfactual(x, target, counterfactual_data, M, generator) + +# plot(ce) +# ``` +# """ +# function Plots.plot( +# ce_plot::CounterfactualExplanation; +# alpha_ = 0.5, +# plot_up_to::Union{Nothing,Int} = nothing, +# plot_proba::Bool = false, +# n_points = 1000, +# kwargs..., +# ) + +# ce = deepcopy(ce_plot) +# ce.data = DataPreprocessing.subsample(ce.data, n_points) + +# max_iter = total_steps(ce) +# max_iter = if isnothing(plot_up_to) +# total_steps(ce) +# else +# minimum([plot_up_to, max_iter]) +# end +# max_iter += 1 +# ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) + +# for t = 1:max_iter +# final_state = t == max_iter +# plot_state(ce, t, final_state; ingredients...) +# end + +# plt = if plot_proba +# Plots.plot(ingredients.p1, ingredients.p2; kwargs...) +# else +# Plots.plot(ingredients.p1; kwargs...) +# end + +# return plt +# end """ animate_path(ce::CounterfactualExplanation, path=tempdir(); plot_proba::Bool=false, kwargs...) @@ -80,6 +162,9 @@ function animate_path( plot_proba::Bool = false, kwargs..., ) + + alpha = get(plotattributes, :alpha, 0.5) + max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) total_steps(ce) @@ -87,7 +172,7 @@ function animate_path( minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) + ingredients = set_up_plots(ce; alpha = alpha, plot_proba = plot_proba, kwargs...) anim = @animate for t = 1:max_iter final_state = t == max_iter @@ -167,16 +252,16 @@ end A helper method that prepares data for plotting. """ function set_up_plots(ce::CounterfactualExplanation; alpha, plot_proba, kwargs...) - p1 = plot(ce.M, ce.data; target = ce.target, alpha = alpha, kwargs...) 
- p2 = plot(; xlims = (1, total_steps(ce) + 1), ylims = (0, 1)) + # p1 = plot(ce.M, ce.data; target = ce.target, alpha = alpha, kwargs...) + # p2 = plot(; xlims = (1, total_steps(ce) + 1), ylims = (0, 1)) path_embedded = embed_path(ce) path_labels = CounterfactualExplanations.counterfactual_label_path(ce) y_levels = ce.data.y_levels path_labels = map(x -> CategoricalArrays.categorical(x; levels = y_levels), path_labels) path_probs = CounterfactualExplanations.target_probs_path(ce) output = ( - p1 = p1, - p2 = p2, + # p1 = p1, + # p2 = p2, path_embedded = path_embedded, path_labels = path_labels, path_probs = path_probs, diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl index bae9ddc..92e3b7b 100644 --- a/src/CounterfactualExplations/data.jl +++ b/src/CounterfactualExplations/data.jl @@ -12,9 +12,9 @@ function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Sy else @info "Training model to compress data." if dim_red == :pca - tfn = MultivariateStats.fit(PCA, X_train; maxoutdim=2) + tfn = MultivariateStats.fit(PCA, X_train; maxoutdim = 2) elseif dim_red == :tsne - tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim=2) + tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim = 2) end data.input_encoder = nothing X = isnothing(X) ? X_train : X @@ -22,7 +22,7 @@ function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Sy end # Transforming: - X = typeof(X) <: Vector{<:Matrix} ? MLUtils.stack(X, dims=2) : X + X = typeof(X) <: Vector{<:Matrix} ? 
MLUtils.stack(X, dims = 2) : X if !isnothing(tfn) && !isnothing(X) X = mapslices(x -> MultivariateStats.predict(tfn, x), X, dims = 1) else @@ -53,14 +53,19 @@ function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol = :pca) return X', y, multi_dim end -@recipe function f(data::CounterfactualData; dim_red=:pca) +@recipe function f(data::CounterfactualData; dim_red = :pca) # Set up: X, y, _ = prepare_for_plotting(data; dim_red = dim_red) - _c = Int.(y.refs) - group := y - markercolor := _c - # return data - return X[:, 1], X[:, 2] + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end + end end diff --git a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index 2bd40bc..0661720 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -4,13 +4,13 @@ using NearestNeighborModels: KNNClassifier @recipe function f( M::AbstractFittedModel, - data::CounterfactualData; - target=nothing, - length_out=100, - zoom=-0.1, - dim_red=:pca, - plot_loss=false, - loss_fun=nothing, + data::CounterfactualData; + target = nothing, + length_out = 100, + zoom = -0.1, + dim_red = :pca, + plot_loss = false, + loss_fun = nothing, ) # Asserts @@ -23,6 +23,53 @@ using NearestNeighborModels: KNNClassifier # Plot attributes linewidth --> 0.1 + contour_series, X, y, xlims, ylims = setup_model_plot( + M, + data, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, + ) + + xlims --> xlims + ylims --> ylims + + # Contour plot: + @series begin + seriestype := :contourf + contour_series[1], contour_series[2], contour_series[3] + end + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], 
X[group_idx, 2] + end + end + +end + +function setup_model_plot( + M::AbstractFittedModel, + data::CounterfactualData, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, +) X, _ = DataPreprocessing.unpack_data(data) ŷ = probs(M, X) # true predictions if size(ŷ, 1) > 1 @@ -38,7 +85,7 @@ using NearestNeighborModels: KNNClassifier end target_encoded = data.output_encoder(target) - X, y, multi_dim = prepare_for_plotting(data; dim_red=dim_red) + X, y, multi_dim = prepare_for_plotting(data; dim_red = dim_red) # Surface range: zoom = zoom * maximum(abs.(X)) @@ -53,17 +100,16 @@ using NearestNeighborModels: KNNClassifier else ylims = ylims .+ (zoom, -zoom) end - x_range = convert.(eltype(X), range(xlims[1]; stop=xlims[2], length=length_out)) - y_range = convert.(eltype(X), range(ylims[1]; stop=ylims[2], length=length_out)) - - xlims --> xlims - ylims --> ylims + x_range = convert.(eltype(X), range(xlims[1]; stop = xlims[2], length = length_out)) + y_range = convert.(eltype(X), range(ylims[1]; stop = ylims[2], length = length_out)) plot_loss = plot_loss || !isnothing(loss_fun) if plot_loss # Loss surface: - Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range] + Z = [ + loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range + ] else # Prediction surface: if multi_dim @@ -91,126 +137,12 @@ using NearestNeighborModels: KNNClassifier target_idx = get_target_index(data.y_levels, target) z = plot_loss ? 
Z[1, :] : Z[target_idx, :] - # Contour plot: - @series begin - seriestype := :contourf - x_range, y_range, z - end + # Collect: + contour_series = (x_range, y_range, z) - # Scatter plot: - for (i, x) in enumerate(unique(sort(y))) - @series begin - seriestype := :scatter - markercolor := i - group_idx = findall(y .== x) - label --> "$(x)" - X[group_idx, 1], X[group_idx, 2] - end - end - + return contour_series, X, y, xlims, ylims end - -# function Plots.plot( -# M::AbstractFittedModel, -# data::DataPreprocessing.CounterfactualData; -# target::Union{Nothing,RawTargetType} = nothing, -# colorbar = true, -# title = "", -# length_out = 100, -# zoom = -0.1, -# xlims = nothing, -# ylims = nothing, -# linewidth = 0.1, -# alpha = 1.0, -# contour_alpha = 1.0, -# dim_red::Symbol = :pca, -# plot_loss::Bool = false, -# loss_fun::Union{Nothing,Function} = nothing, -# kwargs..., -# ) -# X, _ = DataPreprocessing.unpack_data(data) -# ŷ = probs(M, X) # true predictions -# if size(ŷ, 1) > 1 -# ŷ = vec(Flux.onecold(ŷ, 1:size(ŷ, 1))) -# else -# ŷ = vec(ŷ) -# end - -# # Target: -# if isnothing(target) -# target = data.y_levels[1] -# @info "No target label supplied, using first." 
-# end -# target_encoded = data.output_encoder(target) - -# X, y, multi_dim = prepare_for_plotting(data; dim_red = dim_red) - -# # Surface range: -# zoom = zoom * maximum(abs.(X)) -# if isnothing(xlims) -# xlims = (minimum(X[:, 1]), maximum(X[:, 1])) .+ (zoom, -zoom) -# else -# xlims = xlims .+ (zoom, -zoom) -# end -# if isnothing(ylims) -# ylims = (minimum(X[:, 2]), maximum(X[:, 2])) .+ (zoom, -zoom) -# else -# ylims = ylims .+ (zoom, -zoom) -# end -# x_range = convert.(eltype(X), range(xlims[1]; stop = xlims[2], length = length_out)) -# y_range = convert.(eltype(X), range(ylims[1]; stop = ylims[2], length = length_out)) - -# plot_loss = plot_loss || !isnothing(loss_fun) - -# if plot_loss -# # Loss surface: -# Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range] -# else -# # Prediction surface: -# if multi_dim -# knn1, y_train = voronoi(X, ŷ) -# predict_ = -# (X::AbstractVector) -> vec( -# pdf( -# MLJBase.predict(knn1, MLJBase.table(reshape(X, 1, 2))), -# DataAPI.levels(y_train), -# ), -# ) -# Z = [predict_([x, y]) for x in x_range, y in y_range] -# else -# predict_ = function (X::AbstractVector) -# X = permutedims(permutedims(X)) -# z = predict_proba(M, data, X) -# return z -# end -# Z = [predict_([x, y]) for x in x_range, y in y_range] -# end -# end - -# # Pre-processes: -# Z = reduce(hcat, Z) -# target_idx = get_target_index(data.y_levels, target) -# z = plot_loss ? Z[1, :] : Z[target_idx, :] - -# # Contour: -# Plots.contourf( -# x_range, -# y_range, -# z; -# colorbar = colorbar, -# title = title, -# linewidth = linewidth, -# xlims = xlims, -# ylims = ylims, -# kwargs..., -# alpha = contour_alpha, -# ) - -# # Samples: -# return Plots.scatter!(data; dim_red = dim_red, alpha = alpha, kwargs...) 
-# end - function voronoi(X::AbstractMatrix, y::AbstractVector) knnc = KNNClassifier(; K = 1) # KNNClassifier instantiation X = MLJBase.table(X) diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index b5395dc..fd23ee6 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -88,7 +88,7 @@ function Plots.plot( # Plot predict_ = function (X::AbstractVector) - z = LaplaceRedux.predict(la,X; link_approx = link_approx) + z = LaplaceRedux.predict(la, X; link_approx = link_approx) if LaplaceRedux.outdim(la) == 1 # binary z = [1.0 - z[1], z[1]] end From 22a8362d31e6ba5c03b4d13efbd391e249610e6c Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 11:07:42 +0200 Subject: [PATCH 06/12] ce plot sorted --- .../counterfactuals.jl | 108 ++++-------------- 1 file changed, 21 insertions(+), 87 deletions(-) diff --git a/src/CounterfactualExplations/counterfactuals.jl b/src/CounterfactualExplations/counterfactuals.jl index 806db33..469996c 100644 --- a/src/CounterfactualExplations/counterfactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -9,7 +9,6 @@ using MLUtils: stack plot_loss=false, loss_fun=nothing, plot_up_to::Union{Nothing,Int} = nothing, - plot_proba::Bool = false, n_points = 1000, ) @@ -22,6 +21,9 @@ using MLUtils: stack # Get user-defined arguments: xlims = get(plotattributes, :xlims, nothing) ylims = get(plotattributes, :ylims, nothing) + ms = get(plotattributes, :markersize, 3) + mspath = ms*2 + msfinal = mspath*2 # Plot attributes linewidth --> 0.1 @@ -53,14 +55,13 @@ using MLUtils: stack @series begin seriestype := :scatter markercolor := i + markersize := ms group_idx = findall(y .== x) label --> "$(x)" X[group_idx, 1], X[group_idx, 2] end end - alpha = get(plotattributes, :alpha, 0.5) - max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) total_steps(ce) @@ -68,77 +69,27 @@ using MLUtils: stack minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = 
set_up_plots(ce; alpha = alpha, plot_proba = plot_proba) - - for X in eachslice(ingredients.path_embedded, dims=3) - for (x,y) in zip(eachcol(X),ingredients.path_labels) + path_x, path_y = setup_ce_plot(ce) + + # Outer loop over number of counterfactuals: + for (num_counterfactual, X) in enumerate(eachslice(path_x, dims=3)) + # Inner loop over counterfactual search steps: + steps = zip(eachcol(X), path_y) + for (i,(x,y)) in enumerate(steps) + i <= max_iter || break + _annotate = i == length(steps) && ce.num_counterfactuals > 1 @series begin seriestype := :scatter - markercolor := CategoricalArrays.levelcode.(y[1]) + markercolor := CategoricalArrays.levelcode.(y[num_counterfactual]) + markersize := i == length(steps) ? msfinal : mspath + series_annotation := _annotate ? text("C$(num_counterfactual)", mspath) : nothing label := :none - x[1,:], X[2,:] + x[1,:], x[2,:] end end end end -# """ -# Plots.plot( -# ce::CounterfactualExplanation; -# alpha_ = 0.5, -# plot_up_to::Union{Nothing,Int} = nothing, -# plot_proba::Bool = false, -# kwargs..., -# ) - -# Calling `plot` on an instance of type `CounterfactualExplanation` returns a plot that visualises the entire counterfactual path. For multi-dimensional input data, the data is first compressed into two dimensions. The decision boundary is then approximated using using a Nearest Neighbour classifier. This is still somewhat experimental at the moment. 
- - -# # Examples - -# ```julia-repl -# # Search: -# generator = GenericGenerator() -# ce = generate_counterfactual(x, target, counterfactual_data, M, generator) - -# plot(ce) -# ``` -# """ -# function Plots.plot( -# ce_plot::CounterfactualExplanation; -# alpha_ = 0.5, -# plot_up_to::Union{Nothing,Int} = nothing, -# plot_proba::Bool = false, -# n_points = 1000, -# kwargs..., -# ) - -# ce = deepcopy(ce_plot) -# ce.data = DataPreprocessing.subsample(ce.data, n_points) - -# max_iter = total_steps(ce) -# max_iter = if isnothing(plot_up_to) -# total_steps(ce) -# else -# minimum([plot_up_to, max_iter]) -# end -# max_iter += 1 -# ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) - -# for t = 1:max_iter -# final_state = t == max_iter -# plot_state(ce, t, final_state; ingredients...) -# end - -# plt = if plot_proba -# Plots.plot(ingredients.p1, ingredients.p2; kwargs...) -# else -# Plots.plot(ingredients.p1; kwargs...) -# end - -# return plt -# end - """ animate_path(ce::CounterfactualExplanation, path=tempdir(); plot_proba::Bool=false, kwargs...) @@ -242,31 +193,14 @@ Base.@kwdef struct PlotIngredients end """ - set_up_plots( - ce::CounterfactualExplanation; - alpha, - plot_proba, - kwargs... - ) + setup_ce_plot(ce::CounterfactualExplanation) A helper method that prepares data for plotting. """ -function set_up_plots(ce::CounterfactualExplanation; alpha, plot_proba, kwargs...) - # p1 = plot(ce.M, ce.data; target = ce.target, alpha = alpha, kwargs...) 
- # p2 = plot(; xlims = (1, total_steps(ce) + 1), ylims = (0, 1)) +function setup_ce_plot(ce::CounterfactualExplanation) path_embedded = embed_path(ce) path_labels = CounterfactualExplanations.counterfactual_label_path(ce) y_levels = ce.data.y_levels - path_labels = map(x -> CategoricalArrays.categorical(x; levels = y_levels), path_labels) - path_probs = CounterfactualExplanations.target_probs_path(ce) - output = ( - # p1 = p1, - # p2 = p2, - path_embedded = path_embedded, - path_labels = path_labels, - path_probs = path_probs, - alpha = alpha, - plot_proba = plot_proba, - ) - return output + path_labels = map(x -> CategoricalArrays.categorical(x; levels=y_levels), path_labels) + return path_embedded, path_labels end From 6454afa044c991d322aa7a0b2082233573fc0f59 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 11:30:03 +0200 Subject: [PATCH 07/12] animate_path also working --- .../counterfactuals.jl | 96 ++++--------------- 1 file changed, 20 insertions(+), 76 deletions(-) diff --git a/src/CounterfactualExplations/counterfactuals.jl b/src/CounterfactualExplations/counterfactuals.jl index 469996c..359a4de 100644 --- a/src/CounterfactualExplations/counterfactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,5 +1,3 @@ -using MLUtils: stack - @recipe function f( ce::CounterfactualExplanation; target=nothing, @@ -9,18 +7,28 @@ using MLUtils: stack plot_loss=false, loss_fun=nothing, plot_up_to::Union{Nothing,Int} = nothing, - n_points = 1000, + n_points = nothing, ) - ce = deepcopy(ce) - ce.data = DataPreprocessing.subsample(ce.data, n_points) + if !isnothing(n_points) + if n_points < size(ce.data.X, 2) + @info "Undersampling to $(n_points) points." + else + @info "Oversampling to $(n_points) points." 
+ end + xlims, ylims = extrema(ce.data.X[1, :]), extrema(ce.data.X[2, :]) + ce = deepcopy(ce) + ce.data = DataPreprocessing.subsample(ce.data, n_points) + else + xlims, ylims = nothing, nothing + end # Asserts @assert !plot_loss || !isnothing(loss_fun) "Need to provide a loss function to plot the loss, e.g. (`loss_fun=Flux.Losses.logitcrossentropy`)." # Get user-defined arguments: - xlims = get(plotattributes, :xlims, nothing) - ylims = get(plotattributes, :ylims, nothing) + xlims = get(plotattributes, :xlims, xlims) + ylims = get(plotattributes, :ylims, ylims) ms = get(plotattributes, :markersize, 3) mspath = ms*2 msfinal = mspath*2 @@ -77,11 +85,12 @@ using MLUtils: stack steps = zip(eachcol(X), path_y) for (i,(x,y)) in enumerate(steps) i <= max_iter || break + _final_iter = i == length(steps) || i == max_iter _annotate = i == length(steps) && ce.num_counterfactuals > 1 @series begin seriestype := :scatter markercolor := CategoricalArrays.levelcode.(y[num_counterfactual]) - markersize := i == length(steps) ? msfinal : mspath + markersize := _final_iter ? msfinal : mspath series_annotation := _annotate ? text("C$(num_counterfactual)", mspath) : nothing label := :none x[1,:], x[2,:] @@ -108,14 +117,11 @@ animate_path(ce) function animate_path( ce::CounterfactualExplanation, path = tempdir(); - alpha_ = 0.5, plot_up_to::Union{Nothing,Int} = nothing, - plot_proba::Bool = false, - kwargs..., + legend = :topright, + kwrgs..., ) - alpha = get(plotattributes, :alpha, 0.5) - max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) total_steps(ce) @@ -123,75 +129,13 @@ function animate_path( minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = set_up_plots(ce; alpha = alpha, plot_proba = plot_proba, kwargs...) anim = @animate for t = 1:max_iter - final_state = t == max_iter - plot_state(ce, t, final_state; ingredients...) - if plot_proba - plot(ingredients.p1, ingredients.p2; kwargs...) - else - plot(ingredients.p1; kwargs...) 
- end + plot(ce; plot_up_to=t, legend=legend, kwrgs...) end return anim end -""" - plot_state( - ce::CounterfactualExplanation, - t::Int, - final_state::Bool; - kwargs... - ) - -Helper function that plots a single step of the counterfactual path. -""" -function plot_state(ce::CounterfactualExplanation, t::Int, final_state::Bool; kwargs...) - args = PlotIngredients(; kwargs...) - x1 = args.path_embedded[1, t, :] - x2 = args.path_embedded[2, t, :] - y = args.path_labels[t] - _c = CategoricalArrays.levelcode.(y) - n_ = ce.num_counterfactuals - label_ = reshape(["C$i" for i = 1:n_], 1, n_) - if !final_state - scatter!(args.p1, x1, x2; group = y, colour = _c, ms = 5, label = "") - else - scatter!(args.p1, x1, x2; group = y, colour = _c, ms = 10, label = "") - if n_ > 1 - label_1 = vec([text(lab, 5) for lab in label_]) - annotate!(x1, x2, label_1) - end - end - if args.plot_proba - probs_ = reshape(reduce(vcat, args.path_probs[1:t]), t, n_) - if t == 1 && n_ > 1 - label_2 = label_ - else - label_2 = "" - end - plot!( - args.p2, - probs_; - label = label_2, - color = reshape(1:n_, 1, n_), - title = "p(y=$(ce.target))", - ) - end -end - -"A container used for plotting." 
-Base.@kwdef struct PlotIngredients - p1::Any - p2::Any - path_embedded::Any - path_labels::Any - path_probs::Any - alpha::Any - plot_proba::Any -end - """ setup_ce_plot(ce::CounterfactualExplanation) From 04fabcdfbb04ca6ff80d3e43b35cf5cc19f55d40 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 13:03:26 +0200 Subject: [PATCH 08/12] Laplace also done it seems --- src/LaplaceRedux/LaplaceRedux.jl | 227 +++++++++++++++++-------------- 1 file changed, 128 insertions(+), 99 deletions(-) diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index fd23ee6..5d17224 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -1,127 +1,156 @@ using LaplaceRedux using Trapz -function Plots.plot( +@recipe function f( la::Laplace, X::AbstractArray, y::AbstractArray; - link_approx::Symbol = :probit, - target::Union{Nothing,Real} = nothing, - colorbar = true, - title = nothing, - length_out = 50, - zoom = -1, - xlims = nothing, - ylims = nothing, - linewidth = 0.1, - lw = 4, - kwargs..., + link_approx::Symbol=:probit, + target::Union{Nothing,Real}=nothing, + length_out=50, + zoom=-1, ) + + # Asserts: if la.likelihood == :regression @assert size(X, 1) == 1 "Cannot plot regression for multiple input variables." else @assert size(X, 1) == 2 "Cannot plot classification for more than two input variables." 
end + # Get user-defined arguments: + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + title = get(plotattributes, :title, nothing) + + # Plot attributes + lw = get(plotattributes, :linewidth, 1) + lw_yhat = lw*2 + lw_contour = lw*0.1 + if la.likelihood == :regression - # REGRESSION + xrange, yrange, xlims, ylims = surface_range(X, y, xlims, ylims, zoom, length_out) + xlims := xlims + ylims := ylims - # Surface range: - if isnothing(xlims) - xlims = (minimum(X), maximum(X)) .+ (zoom, -zoom) - else - xlims = xlims .+ (zoom, -zoom) - end - if isnothing(ylims) - ylims = (minimum(y), maximum(y)) .+ (zoom, -zoom) - else - ylims = ylims .+ (zoom, -zoom) - end - x_range = range(xlims[1]; stop = xlims[2], length = length_out) - y_range = range(ylims[1]; stop = ylims[2], length = length_out) - - title = isnothing(title) ? "" : title - - # Plot: - scatter( - vec(X), - vec(y); - label = "ytrain", - xlim = xlims, - ylim = ylims, - lw = lw, - title = title, - kwargs..., - ) - _x = collect(x_range)[:, :]' - normal_distr, fμ, fvar = LaplaceRedux.glm_predictive_distribution(la, _x) + # Plot predictions: + _x = collect(xrange)[:, :]' + fμ, fvar = LaplaceRedux.predict(la, _x) fμ = vec(fμ) fσ = vec(sqrt.(fvar)) - pred_std = sqrt.(fσ .^ 2 .+ la.prior.σ^2) - plot!( - x_range, - fμ; - color = 2, - label = "yhat", - ribbon = (1.96 * pred_std, 1.96 * pred_std), - lw = lw, - kwargs..., - ) # the specific values 1.96 are used here to create a 95% confidence interval - else - - # CLASSIFICATION - - # Surface range: - if isnothing(xlims) - xlims = (minimum(X[1, :]), maximum(X[1, :])) .+ (zoom, -zoom) - else - xlims = xlims .+ (zoom, -zoom) + @series begin + seriestype := :path + ribbon := (1.96 * fσ, 1.96 * fσ) + linewidth := lw_yhat + label --> "yhat" + xrange, fμ end - if isnothing(ylims) - ylims = (minimum(X[2, :]), maximum(X[2, :])) .+ (zoom, -zoom) - else - ylims = ylims .+ (zoom, -zoom) + + # Scatter training data: + @series begin + 
seriestype := :scatter + label --> "ytrain" + vec(X), vec(y) end - x_range = range(xlims[1]; stop = xlims[2], length = length_out) - y_range = range(ylims[1]; stop = ylims[2], length = length_out) - - # Plot - predict_ = function (X::AbstractVector) - z = LaplaceRedux.predict(la, X; link_approx = link_approx) - if LaplaceRedux.outdim(la) == 1 # binary - z = [1.0 - z[1], z[1]] - end - return z + + end + + if la.likelihood == :classification + + xrange, yrange, xlims, ylims = surface_range(X, xlims, ylims, zoom, length_out) + xlims := xlims + ylims := ylims + + Z, target, title = get_contour(la, xrange, yrange, link_approx, target, title) + + # Contour plot: + @series begin + seriestype := :contourf + linewidth := lw_contour + title --> title + xrange, yrange, Z[Int(target), :] end - Z = [predict_([x, y]) for x in x_range, y in y_range] - Z = reduce(hcat, Z) - if LaplaceRedux.outdim(la) > 1 - if isnothing(target) - @info "No target label supplied, using first." + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + println(group_idx) + label --> "$(x)" + X[1, group_idx], X[2, group_idx] end - target = isnothing(target) ? 1 : target - title = isnothing(title) ? "p̂(y=$(target))" : title - else - target = isnothing(target) ? 2 : target - title = isnothing(title) ? "p̂(y=$(target-1))" : title end - # Contour: - contourf( - x_range, - y_range, - Z[Int(target), :]; - colorbar = colorbar, - title = title, - linewidth = linewidth, - xlims = xlims, - ylims = ylims, - kwargs..., - ) - # Samples: - scatter!(X[1, :], X[2, :]; group = Int.(y), color = Int.(y), kwargs...) 
end + +end + +function surface_range( + X::AbstractArray, y::AbstractArray, + xlims,ylims,zoom,length_out, +) + + # Surface range: + if isnothing(xlims) + xlims = (minimum(X), maximum(X)) .+ (zoom, -zoom) + else + xlims = xlims .+ (zoom, -zoom) + end + if isnothing(ylims) + ylims = (minimum(y), maximum(y)) .+ (zoom, -zoom) + else + ylims = ylims .+ (zoom, -zoom) + end + x_range = range(xlims[1]; stop = xlims[2], length = length_out) + y_range = range(ylims[1]; stop = ylims[2], length = length_out) + return x_range, y_range, xlims, ylims + +end + +function surface_range(X::AbstractArray,xlims,ylims,zoom,length_out) + + if isnothing(xlims) + xlims = (minimum(X[1, :]), maximum(X[1, :])) .+ (zoom, -zoom) + else + xlims = xlims .+ (zoom, -zoom) + end + if isnothing(ylims) + ylims = (minimum(X[2, :]), maximum(X[2, :])) .+ (zoom, -zoom) + else + ylims = ylims .+ (zoom, -zoom) + end + x_range = range(xlims[1]; stop = xlims[2], length = length_out) + y_range = range(ylims[1]; stop = ylims[2], length = length_out) + + return x_range, y_range, xlims, ylims +end + +function get_contour(la::Laplace, x_range, y_range, link_approx, target, title) + + predict_ = function (la, X::AbstractVector) + z = LaplaceRedux.predict(la, X; link_approx = link_approx) + if LaplaceRedux.outdim(la) == 1 # binary + z = [1.0 - z[1], z[1]] + end + return z + end + Z = [predict_(la, [x, y]) for x in x_range, y in y_range] + Z = reduce(hcat, Z) + if LaplaceRedux.outdim(la) > 1 + if isnothing(target) + @info "No target label supplied, using first." + end + target = isnothing(target) ? 1 : target + title = isnothing(title) ? "p̂(y=$(target))" : title + else + target = isnothing(target) ? 2 : target + title = isnothing(title) ? 
"p̂(y=$(target-1))" : title + end + + return Z, target, title end """ From 81a97b7a7e6d1e292e8282da6dd2cc7ebafb8c37 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 13:05:33 +0200 Subject: [PATCH 09/12] now done --- src/LaplaceRedux/LaplaceRedux.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index 5d17224..6f00246 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -78,7 +78,6 @@ using Trapz seriestype := :scatter markercolor := i group_idx = findall(y .== x) - println(group_idx) label --> "$(x)" X[1, group_idx], X[2, group_idx] end From 8c0f0567de65e95cb3a3893809809807d85c337c Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 14:34:57 +0200 Subject: [PATCH 10/12] this should be all of them --- docs/Project.toml | 2 + .../ConformalPrediction.jl | 307 +----------------- src/ConformalPrediction/bar.jl | 23 ++ src/ConformalPrediction/classification.jl | 192 +++++++++++ src/ConformalPrediction/regression.jl | 77 +++++ 5 files changed, 297 insertions(+), 304 deletions(-) create mode 100644 src/ConformalPrediction/bar.jl create mode 100644 src/ConformalPrediction/classification.jl create mode 100644 src/ConformalPrediction/regression.jl diff --git a/docs/Project.toml b/docs/Project.toml index 6a60890..41da921 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" +NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240" diff --git a/src/ConformalPrediction/ConformalPrediction.jl b/src/ConformalPrediction/ConformalPrediction.jl index e4df1f1..3d77911 100644 --- a/src/ConformalPrediction/ConformalPrediction.jl 
+++ b/src/ConformalPrediction/ConformalPrediction.jl @@ -38,307 +38,6 @@ function get_names(X) return _names end -@doc raw""" - Plots.contourf(conf_model::ConformalModel,fitresult,X,y;kwargs...) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal classifier with exactly two input variable. Data (`X`,`y`) are plotted as dots and overlaid with predictions sets. `y` is used to indicate the ground-truth labels of samples by colour. Samples are visualized in a two-dimensional feature space, so it is expected that `X` ``\in \mathcal{R}^2``. By default, a contour is used to visualize the softmax output of the conformal classifier for the target label, where `target` indicates can be used to define the index of the target label. Transparent regions indicate that the prediction set does not include the `target` label. - -## Target - -In the binary case, `target` defaults to `2`, indexing the second label: assuming the labels are `[0,1]` then the softmax output for `1` is shown. In the multi-class cases, `target` defaults to the first class: for example, if the labels are `["🐶", "🐱", "🐭"]` (in that order) then the contour indicates the softmax output for `"🐶"`. - -## Set Size - -If `plot_set_size` is set to `true`, then the contour instead visualises the the set size. - -## Univariate and Higher Dimensional Inputs - -For univariate of multiple inputs (>2), this function is not applicable. See [`Plots.areaplot(conf_model::ConformalProbabilisticSet, fitresult, X, y; kwargs...)`](@ref) for an alternative way to visualize prediction for any conformal classifier. 
- -""" -function Plots.contourf( - conf_model::ConformalProbabilisticSet, - fitresult, - X, - y; - target::Union{Nothing,Real} = nothing, - ntest = 50, - zoom = -1, - xlims = nothing, - ylims = nothing, - plot_set_size = false, - plot_classification_loss = false, - plot_set_loss = false, - temp = 0.1, - κ = 0, - loss_matrix = UniformScaling(1.0), - kwargs..., -) - - # Setup: - X = permutedims(MLJBase.matrix(X)) - - @assert size(X, 1) == 2 "Can only create contour plot for conformal classifier with exactly two input variables." - - x1 = X[1, :] - x2 = X[2, :] - - # Plot limits: - xlims, ylims = generate_lims(x1, x2, xlims, ylims, zoom) - - # Surface range: - x1range = range(xlims[1]; stop = xlims[2], length = ntest) - x2range = range(ylims[1]; stop = ylims[2], length = ntest) - - # Target - if !isnothing(target) - @assert target in levels(y) "Specified target does not match any of the labels." - end - if length(unique(y)) > 1 - if isnothing(target) - @info "No target label supplied, using first." - end - target = isnothing(target) ? levels(y)[1] : target - if plot_set_size - _default_title = "Set size" - elseif plot_set_loss - _default_title = "Smooth set loss" - elseif plot_classification_loss - _default_title = "ℒ(C,$(target))" - else - _default_title = "p̂(y=$(target))" - end - else - if plot_set_size - _default_title = "Set size" - elseif plot_set_loss - _default_title = "Smooth set loss" - elseif plot_classification_loss - _default_title = "ℒ(C,$(target-1))" - else - _default_title = "p̂(y=$(target-1))" - end - end - title = !@isdefined(title) ? _default_title : title - - # Predictions - Z = [] - for x2 in x2range, x1 in x1range - p̂ = MLJBase.predict(conf_model, fitresult, table([x1 x2]))[1] - if plot_set_size - z = ismissing(p̂) ? 
0 : sum(pdf.(p̂, p̂.decoder.classes) .> 0) - elseif plot_classification_loss - _target = categorical([target]; levels = levels(y)) - z = ConformalPrediction.ConformalTraining.classification_loss( - conf_model, - fitresult, - [x1 x2], - _target; - temp = temp, - loss_matrix = loss_matrix, - ) - elseif plot_set_loss - z = ConformalPrediction.ConformalTraining.smooth_size_loss( - conf_model, - fitresult, - [x1 x2]; - κ = κ, - temp = temp, - ) - else - z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y)) - z = replace(z, 0 => missing) - end - push!(Z, z) - end - Z = reduce(hcat, Z) - Z = Z[findall(levels(y) .== target)[1][1], :] - - # Contour: - if plot_set_size - _n = length(unique(y)) - clim = (0, _n) - plt = contourf( - x1range, - x2range, - Z; - title = title, - xlims = xlims, - ylims = ylims, - c = cgrad(:blues, _n + 1; categorical = true), - clim = clim, - kwargs..., - ) - else - plt = contourf( - x1range, - x2range, - Z; - title = title, - xlims = xlims, - ylims = ylims, - c = cgrad(:blues), - linewidth = 0, - kwargs..., - ) - end - - # Samples: - y = typeof(y) <: CategoricalArrays.CategoricalArray ? y : Int.(y) - return scatter!(plt, x1, x2; group = y, kwargs...) -end - -""" - Plots.areaplot( - conf_model::ConformalProbabilisticSet, fitresult, X, y; - input_var::Union{Nothing,Int,Symbol}=nothing, - kwargs... - ) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of any fitted conformal classifier. Using a stacked area chart, this function plots the softmax output(s) contained the the conformal predictions set on the vertical axis against an input variable `X` on the horizontal axis. In the case of multiple input variables, the `input_var` argument can be used to specify the desired input variable. 
-""" -function Plots.areaplot( - conf_model::ConformalProbabilisticSet, - fitresult, - X, - y; - input_var::Union{Nothing,Int,Symbol} = nothing, - kwargs..., -) - - # Setup: - Xraw = deepcopy(X) - _names = get_names(Xraw) - X = permutedims(MLJBase.matrix(X)) - - # Dimensions: - if size(X, 1) > 1 - if isnothing(input_var) - @info "Multiple inputs no input variable (`input_var`) specified: defaulting to first variable." - idx = 1 - else - if typeof(input_var) == Int - idx = input_var - else - @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." - idx = findall(_names .== input_var)[1] - end - end - x = X[idx, :] - else - idx = 1 - x = X - end - - # Predictions: - ŷ = MLJBase.predict(conf_model, fitresult, Xraw) - nout = length(levels(y)) - ŷ = - map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y) - ŷ = permutedims(ŷ) - - return areaplot(x, ŷ; kwargs...) -end - -""" - Plots.plot( - conf_model::ConformalInterval, fitresult, X, y; - kwrgs... - ) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal regressor. Data (`X`,`y`) are plotted as dots and overlaid with predictions intervals. `y` is plotted on the vertical axis against a single variable `X` on the horizontal axis. A shaded area indicates the prediction interval. The line in the center of the interval is the midpoint of the interval and can be interpreted as the point estimate of the conformal regressor. In case `X` is multi-dimensional, `input_var` can be used to specify the input variable of interest that will be used for the horizontal axis. If unspecified, the first variable will be plotting by default. 
-""" -function Plots.plot( - conf_model::ConformalInterval, - fitresult, - X, - y; - input_var::Union{Nothing,Int,Symbol} = nothing, - xlims::Union{Nothing,Tuple} = nothing, - ylims::Union{Nothing,Tuple} = nothing, - zoom::Real = -0.5, - train_lab::Union{Nothing,String} = nothing, - test_lab::Union{Nothing,String} = nothing, - ymid_lw::Int = 1, - kwargs..., -) - - # Setup - title = !@isdefined(title) ? "" : title - train_lab = isnothing(train_lab) ? "Observed" : train_lab - test_lab = isnothing(test_lab) ? "Predicted" : test_lab - - Xraw = deepcopy(X) - _names = get_names(Xraw) - X = permutedims(MLJBase.matrix(X)) - - # Dimensions: - if size(X, 1) > 1 - if isnothing(input_var) - @info "Multivariate input for regression with no input variable (`input_var`) specified: defaulting to first variable." - idx = 1 - else - if typeof(input_var) == Int - idx = input_var - else - @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." - idx = findall(_names .== input_var)[1] - end - end - x = X[idx, :] - else - idx = 1 - x = X - end - - # Plot limits: - xlims, ylims = generate_lims(x, y, xlims, ylims, zoom) - - # Plot training data: - plt = scatter( - vec(x), - vec(y); - label = train_lab, - xlim = xlims, - ylim = ylims, - title = title, - kwargs..., - ) - - # Plot predictions: - ŷ = MLJBase.predict(conf_model, fitresult, Xraw) - lb, ub = eachcol(reduce(vcat, map(y -> permutedims(collect(y)), ŷ))) - ymid = (lb .+ ub) ./ 2 - yerror = (ub .- lb) ./ 2 - xplot = vec(x) - _idx = sortperm(xplot) - return plot!( - plt, - xplot[_idx], - ymid[_idx]; - label = test_lab, - ribbon = (yerror, yerror), - lw = ymid_lw, - kwargs..., - ) -end - -""" - Plots.bar(conf_model::ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...) - -A `Plots.jl` recipe/method extension that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. 
It can be useful to plot the distribution of set sizes in order to visually asses how adaptive a conformal predictor is. For more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’. -""" -function Plots.bar( - conf_model::ConformalModel, - fitresult, - X; - label = "", - xtickfontsize = 6, - kwrgs..., -) - ŷ = MLJBase.predict(conf_model, fitresult, X) - idx = ConformalPrediction.size_indicator(ŷ) - x = sort(levels(idx); lt = natural) - y = [sum(idx .== _x) for _x in x] - return Plots.bar(x, y; label = label, xtickfontsize = xtickfontsize, kwrgs...) 
-end +include("regression.jl") +include("bar.jl") +include("classification.jl") \ No newline at end of file diff --git a/src/ConformalPrediction/bar.jl b/src/ConformalPrediction/bar.jl new file mode 100644 index 0000000..dbb6bc7 --- /dev/null +++ b/src/ConformalPrediction/bar.jl @@ -0,0 +1,23 @@ +@recipe function f( + conf_model::ConformalModel, + fitresult, + X +) + + # Plot attributes: + xtickfontsize --> 6 + + # Setup: + ŷ = MLJBase.predict(conf_model, fitresult, X) + idx = ConformalPrediction.size_indicator(ŷ) + x = sort(levels(idx); lt=natural) + y = [sum(idx .== _x) for _x in x] + + # Bar chart + @series begin + seriestype := :bar + label --> "" + x, y + end + +end \ No newline at end of file diff --git a/src/ConformalPrediction/classification.jl b/src/ConformalPrediction/classification.jl new file mode 100644 index 0000000..568e957 --- /dev/null +++ b/src/ConformalPrediction/classification.jl @@ -0,0 +1,192 @@ +@recipe function f( + conf_model::ConformalProbabilisticSet, + fitresult, + X, + y; + input_var=nothing, + target=nothing, + ntest=50, + zoom=-1, + plot_set_size=false, + plot_classification_loss=false, + plot_set_loss=false, + temp=0.1, + κ=0, + loss_matrix=UniformScaling(1.0), +) + + # Get user-defined arguments: + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + + if size(permutedims(MLJBase.matrix(X)), 1) > 2 + + # AREA PLOT FOR MULTI-D + + # Plot attributes: + xtickfontsize --> 6 + + # Setup: + Xraw = deepcopy(X) + _names = get_names(Xraw) + X = permutedims(MLJBase.matrix(X)) + + # Dimensions: + if size(X, 1) > 1 + if isnothing(input_var) + @info "Multiple inputs no input variable (`input_var`) specified: defaulting to first variable." + idx = 1 + else + if typeof(input_var) == Int + idx = input_var + else + @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." 
+ idx = findall(_names .== input_var)[1] + end + end + x = X[idx, :] + else + idx = 1 + x = X + end + + # Predictions: + ŷ = MLJBase.predict(conf_model, fitresult, Xraw) + nout = length(levels(y)) + ŷ = + map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y) + ŷ = permutedims(ŷ) + println(x) + println(ŷ[sortperm(x), :]) + + # Area chart + args = (x, ŷ) + data = cumsum(args[end], dims=2) + x = length(args) == 1 ? (axes(data, 1)) : args[1] + seriestype := :line + for i in axes(data, 2) + @series begin + fillrange := i > 1 ? data[:, i-1] : 0 + x, data[:, i] + end + end + + elseif size(permutedims(MLJBase.matrix(X)), 1) == 2 + + # CONTOUR PLOT FOR 2D + + # Setup: + x1, x2, x1range, x2range, Z, xlims, ylims, _default_title = setup_contour_cp( + conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target, + plot_set_size, plot_classification_loss, plot_set_loss, temp, κ, loss_matrix, + ) + + # Contour: + _n = length(unique(y)) + clim = (0, _n) + @series begin + seriestype := :contourf + x1range, x2range, Z + end + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + x1[group_idx], x2[group_idx] + end + end + + end + +end + +function setup_contour_cp( + conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target, + plot_set_size, + plot_classification_loss, + plot_set_loss, + temp, + κ, + loss_matrix, +) + X = permutedims(MLJBase.matrix(X)) + + x1 = X[1, :] + x2 = X[2, :] + + # Plot limits: + xlims, ylims = generate_lims(x1, x2, xlims, ylims, zoom) + + # Surface range: + x1range = range(xlims[1]; stop=xlims[2], length=ntest) + x2range = range(ylims[1]; stop=ylims[2], length=ntest) + + # Target + if !isnothing(target) + @assert target in levels(y) "Specified target does not match any of the labels." + end + if length(unique(y)) > 1 + if isnothing(target) + @info "No target label supplied, using first." 
+ end + target = isnothing(target) ? levels(y)[1] : target + if plot_set_size + _default_title = "Set size" + elseif plot_set_loss + _default_title = "Smooth set loss" + elseif plot_classification_loss + _default_title = "ℒ(C,$(target))" + else + _default_title = "p̂(y=$(target))" + end + else + if plot_set_size + _default_title = "Set size" + elseif plot_set_loss + _default_title = "Smooth set loss" + elseif plot_classification_loss + _default_title = "ℒ(C,$(target-1))" + else + _default_title = "p̂(y=$(target-1))" + end + end + + # Predictions + Z = [] + for x2 in x2range, x1 in x1range + p̂ = MLJBase.predict(conf_model, fitresult, table([x1 x2]))[1] + if plot_set_size + z = ismissing(p̂) ? 0 : sum(pdf.(p̂, p̂.decoder.classes) .> 0) + elseif plot_classification_loss + _target = categorical([target]; levels=levels(y)) + z = ConformalPrediction.ConformalTraining.classification_loss( + conf_model, + fitresult, + [x1 x2], + _target; + temp=temp, + loss_matrix=loss_matrix, + ) + elseif plot_set_loss + z = ConformalPrediction.ConformalTraining.smooth_size_loss( + conf_model, + fitresult, + [x1 x2]; + κ=κ, + temp=temp, + ) + else + z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y)) + z = replace(z, 0 => missing) + end + push!(Z, z) + end + Z = reduce(hcat, Z) + Z = Z[findall(levels(y) .== target)[1][1], :] + + return x1, x2, x1range, x2range, Z, xlims, ylims, _default_title +end \ No newline at end of file diff --git a/src/ConformalPrediction/regression.jl b/src/ConformalPrediction/regression.jl new file mode 100644 index 0000000..09761be --- /dev/null +++ b/src/ConformalPrediction/regression.jl @@ -0,0 +1,77 @@ +@recipe function f( + conf_model::ConformalInterval, + fitresult, + X, + y; + input_var=nothing, + zoom=-0.5, + train_lab=nothing, + test_lab=nothing, +) + + # Get user-defined arguments: + train_lab = isnothing(train_lab) ? "Observed" : train_lab + test_lab = isnothing(test_lab) ? 
"Predicted" : test_lab + title = get(plotattributes, :title, "") + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + + # Plot attributes: + linewidth --> 1 + + # Setup: + x, y, xlims, ylims, Xraw = setup_ci(X, y, input_var, xlims, ylims, zoom) + + # Plot predictions: + ŷ = MLJBase.predict(conf_model, fitresult, Xraw) + lb, ub = eachcol(reduce(vcat, map(y -> permutedims(collect(y)), ŷ))) + ymid = (lb .+ ub) ./ 2 + yerror = (ub .- lb) ./ 2 + xplot = vec(x) + _idx = sortperm(xplot) + @series begin + seriestype := :path + ribbon := (yerror, yerror) + label := test_lab + xplot[_idx], ymid[_idx] + end + + # Scatter observed data: + @series begin + seriestype := :scatter + label := train_lab + vec(x), vec(y) + end + +end + +function setup_ci(X, y, input_var, xlims, ylims, zoom) + + Xraw = deepcopy(X) + _names = get_names(Xraw) + X = permutedims(MLJBase.matrix(X)) + + # Dimensions: + if size(X, 1) > 1 + if isnothing(input_var) + @info "Multivariate input for regression with no input variable (`input_var`) specified: defaulting to first variable." + idx = 1 + else + if input_var isa Integer + idx = input_var + else + @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." 
+ idx = findfirst(==(input_var), _names) + end + end + x = X[idx, :] + else + idx = 1 + x = X + end + + # Plot limits: + xlims, ylims = generate_lims(x, y, xlims, ylims, zoom) + + return x, y, xlims, ylims, Xraw +end \ No newline at end of file From 0a6890bafe3edd0d770b224d0ca72af49c6f8af5 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 15:09:29 +0200 Subject: [PATCH 11/12] now tests should hopefully pass --- test/ConformalPrediction.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ConformalPrediction.jl b/test/ConformalPrediction.jl index c0e1e1f..c80a947 100644 --- a/test/ConformalPrediction.jl +++ b/test/ConformalPrediction.jl @@ -21,9 +21,9 @@ isplot(plt) = typeof(plt) <: Plots.Plot fit!(mach, rows = train) @test isplot(bar(mach.model, mach.fitresult, X)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y; input_var = 1)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y; input_var = :x1)) + @test isplot(plot(mach.model, mach.fitresult, X, y)) + @test isplot(plot(mach.model, mach.fitresult, X, y; input_var = 1)) + @test isplot(plot(mach.model, mach.fitresult, X, y; input_var = :x1)) @test isplot(contourf(mach.model, mach.fitresult, X, y)) @test isplot( contourf(mach.model, mach.fitresult, X, y; zoom = -1, plot_set_size = true), From b8070471c2580c935f630e0afc3bcfef42846345 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 4 Sep 2024 15:44:22 +0200 Subject: [PATCH 12/12] docstrings adjusted --- Project.toml | 2 +- src/ConformalPrediction/bar.jl | 11 ++++- src/ConformalPrediction/classification.jl | 41 ++++++++++++++++++- src/ConformalPrediction/regression.jl | 16 +++++++- .../counterfactuals.jl | 19 ++++++++- src/CounterfactualExplations/data.jl | 7 +++- src/CounterfactualExplations/models.jl | 16 +++++++- src/LaplaceRedux/LaplaceRedux.jl | 30 ++++++++++---- 8 files changed, 125 insertions(+), 17 deletions(-) diff --git a/Project.toml 
b/Project.toml index b01155b..a050b53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaijaPlotting" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" authors = ["Patrick Altmeyer"] -version = "1.2.0" +version = "1.3.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/ConformalPrediction/bar.jl b/src/ConformalPrediction/bar.jl index dbb6bc7..8503187 100644 --- a/src/ConformalPrediction/bar.jl +++ b/src/ConformalPrediction/bar.jl @@ -1,4 +1,13 @@ -@recipe function f( +""" + plot( + conf_model::ConformalModel, + fitresult, + X + ) + +A `Plots.jl` recipe that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. It can be useful to plot the distribution of set sizes in order to visually assess how adaptive a conformal predictor is. For more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’. 
+""" +@recipe function plot( conf_model::ConformalModel, fitresult, X diff --git a/src/ConformalPrediction/classification.jl b/src/ConformalPrediction/classification.jl index 568e957..ad966f1 100644 --- a/src/ConformalPrediction/classification.jl +++ b/src/ConformalPrediction/classification.jl @@ -1,4 +1,41 @@ -@recipe function f( +@doc raw""" + plot( + conf_model::ConformalProbabilisticSet, + fitresult, + X, + y; + input_var=nothing, + target=nothing, + ntest=50, + zoom=-1, + plot_set_size=false, + plot_classification_loss=false, + plot_set_loss=false, + temp=0.1, + κ=0, + loss_matrix=UniformScaling(1.0), + ) + +A `Plots.jl` recipe that can be used to visualize the conformal predictions of a fitted conformal classifier. + +## Two Dimensional Inputs + +Data (`X`,`y`) are plotted as dots and overlaid with prediction sets. `y` is used to indicate the ground-truth labels of samples by colour. Samples are visualized in a two-dimensional feature space, so it is expected that `X` ``\in \mathcal{R}^2``. By default, a contour is used to visualize the softmax output of the conformal classifier for the target label, where `target` can be used to define the index of the target label. Transparent regions indicate that the prediction set does not include the `target` label. + +### Target + +In the binary case, `target` defaults to `2`, indexing the second label: assuming the labels are `[0,1]` then the softmax output for `1` is shown. In the multi-class case, `target` defaults to the first class: for example, if the labels are `["🐶", "🐱", "🐭"]` (in that order) then the contour indicates the softmax output for `"🐶"`. + +### Set Size + +If `plot_set_size` is set to `true`, then the contour instead visualises the set size. 
+ +## Univariate and Higher Dimensional Inputs + +In the case of univariate inputs or higher dimensional inputs, a stacked area plot is created: in particular, this method plots the softmax output(s) contained in the conformal prediction set on the vertical axis against an input variable `X` on the horizontal axis. In the case of multiple input variables, the `input_var` argument can be used to specify the desired input variable. + +""" +@recipe function plot( conf_model::ConformalProbabilisticSet, fitresult, X, @@ -71,7 +108,7 @@ end end - elseif size(permutedims(MLJBase.matrix(X)), 1) == 2 + else # CONTOUR PLOT FOR 2D diff --git a/src/ConformalPrediction/regression.jl b/src/ConformalPrediction/regression.jl index 09761be..4734026 100644 --- a/src/ConformalPrediction/regression.jl +++ b/src/ConformalPrediction/regression.jl @@ -1,4 +1,18 @@ -@recipe function f( +""" + plot( + conf_model::ConformalInterval, + fitresult, + X, + y; + input_var=nothing, + zoom=-0.5, + train_lab=nothing, + test_lab=nothing, + ) + +A `Plots.jl` recipe that can be used to visualize the conformal predictions of a fitted conformal regressor. Data (`X`,`y`) are plotted as dots and overlaid with prediction intervals. `y` is plotted on the vertical axis against a single variable `X` on the horizontal axis. A shaded area indicates the prediction interval. The line in the center of the interval is the midpoint of the interval and can be interpreted as the point estimate of the conformal regressor. In case `X` is multi-dimensional, `input_var` can be used to specify the input variable of interest that will be used for the horizontal axis. If unspecified, the first variable will be plotted by default. 
+""" +@recipe function plot( conf_model::ConformalInterval, fitresult, X, diff --git a/src/CounterfactualExplations/counterfactuals.jl b/src/CounterfactualExplations/counterfactuals.jl index 359a4de..a2d2390 100644 --- a/src/CounterfactualExplations/counterfactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,4 +1,19 @@ -@recipe function f( +""" + plot( + ce::CounterfactualExplanation; + target=nothing, + length_out=100, + zoom=-0.1, + dim_red=:pca, + plot_loss=false, + loss_fun=nothing, + plot_up_to = nothing, + n_points = nothing, + ) + +Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the training data (scatters), model predictions for the specified `target` (contour) and the counterfactual path (scatter). +""" +@recipe function plot( ce::CounterfactualExplanation; target=nothing, length_out=100, @@ -6,7 +21,7 @@ dim_red=:pca, plot_loss=false, loss_fun=nothing, - plot_up_to::Union{Nothing,Int} = nothing, + plot_up_to = nothing, n_points = nothing, ) diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl index 92e3b7b..c9b030d 100644 --- a/src/CounterfactualExplations/data.jl +++ b/src/CounterfactualExplations/data.jl @@ -53,7 +53,12 @@ function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol = :pca) return X', y, multi_dim end -@recipe function f(data::CounterfactualData; dim_red = :pca) +""" + plot(data::CounterfactualData; dim_red = :pca) + +Calling `Plots.plot` on a `data::CounterfactualData` object will generate a scatter plot of the data. 
+""" +@recipe function plot(data::CounterfactualData; dim_red = :pca) # Set up: X, y, _ = prepare_for_plotting(data; dim_red = dim_red) diff --git a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index 0661720..707de47 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -2,7 +2,21 @@ using DataAPI using Distributions: pdf using NearestNeighborModels: KNNClassifier -@recipe function f( +""" + plot( + M::AbstractFittedModel, + data::CounterfactualData; + target = nothing, + length_out = 100, + zoom = -0.1, + dim_red = :pca, + plot_loss = false, + loss_fun = nothing, + ) + +Calling `Plots.plot` on an `AbstractFittedModel` will plot the model's predictions as a contour. The `target` argument can be used to plot the predictions for a specific target variable. The `length_out` argument can be used to control the number of points used to plot the contour. The `zoom` argument can be used to control the zoom of the plot. The `dim_red` argument can be used to control the method used to reduce the dimensionality of the data if it has more than two features. +""" +@recipe function plot( M::AbstractFittedModel, data::CounterfactualData; target = nothing, diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index 6f00246..5120669 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -1,12 +1,25 @@ using LaplaceRedux using Trapz -@recipe function f( +""" + plot( + la::Laplace, + X::AbstractArray, + y::AbstractArray; + link_approx=:probit, + target=nothing, + length_out=50, + zoom=-1, + ) + +Calling `Plots.plot` on a `Laplace` object will plot the posterior predictive distribution and the training data. 
+""" +@recipe function plot( la::Laplace, X::AbstractArray, y::AbstractArray; - link_approx::Symbol=:probit, - target::Union{Nothing,Real}=nothing, + link_approx=:probit, + target=nothing, length_out=50, zoom=-1, ) @@ -158,11 +171,12 @@ end This plot displays the true frequency of points in each confidence interval relative to the predicted fraction of points in that interval. The intervals are taken in step of 0.05 quantiles. -Input: --'la::Laplace': the laplace model to use. --'Y_cal': a vector of true values y_t. --'samp_distr': an array of sampled distributions F(x_t) corresponding to the y_t stacked column-wise. --'n_bins': numbers of bins to use. +## Inputs + +- `la::Laplace` -- the Laplace model to use. +- `y_cal` -- a vector of true values y_t. +- `samp_distr` -- an array of sampled distributions F(x_t) corresponding to the y_t stacked column-wise. +- `n_bins` -- number of bins to use. """ function calibration_plot(la::Laplace, y_cal, samp_distr; n_bins = 20) quantiles = collect(range(0; stop = 1, length = n_bins + 1))