diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..c743950 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "blue" \ No newline at end of file diff --git a/Project.toml b/Project.toml index 07d7d2a..b3c7d0c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaijaPlotting" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" authors = ["Patrick Altmeyer"] -version = "1.2.0" +version = "1.3.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -9,7 +9,6 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" @@ -17,16 +16,18 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" [compat] +Aqua = "0.8" CategoricalArrays = "0.10" ConformalPrediction = "0.1, 1" CounterfactualExplanations = "1.1.5" DataAPI = "1" Distributions = "0.25" -Flux = "0.12, 0.13, 0.14" LaplaceRedux = "0.1, 0.2, 1.0.1" LinearAlgebra = "1.10" MLJBase = "0.21, 0.22, 1" @@ -34,12 +35,16 @@ MLUtils = "0.4" MultivariateStats = "0.9, 0.10" NaturalSort = "1" NearestNeighborModels = "0.2" +OneHotArrays = "0.2.5" Plots = "1" +RecipesBase = "1.3.4" +Test = "1.10" Trapz = "2.0.3" julia = "1.10" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Aqua", "Test"] diff --git a/README.md b/README.md index 4ef9387..19ebb5d 100644 --- a/README.md +++ b/README.md @@ -5,5 +5,6 @@ [![Build Status](https://github.com/JuliaTrustworthyAI/TaijaPlotting.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaTrustworthyAI/TaijaPlotting.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/JuliaTrustworthyAI/TaijaPlotting.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaTrustworthyAI/TaijaPlotting.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) A package for plotting custom symbols from Taija packages. 
\ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index 6a60890..41da921 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" +NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240" diff --git a/src/ConformalPrediction/ConformalPrediction.jl b/src/ConformalPrediction/ConformalPrediction.jl index e4df1f1..3d77911 100644 --- a/src/ConformalPrediction/ConformalPrediction.jl +++ b/src/ConformalPrediction/ConformalPrediction.jl @@ -38,307 +38,6 @@ function get_names(X) return _names end -@doc raw""" - Plots.contourf(conf_model::ConformalModel,fitresult,X,y;kwargs...) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal classifier with exactly two input variable. Data (`X`,`y`) are plotted as dots and overlaid with predictions sets. `y` is used to indicate the ground-truth labels of samples by colour. Samples are visualized in a two-dimensional feature space, so it is expected that `X` ``\in \mathcal{R}^2``. By default, a contour is used to visualize the softmax output of the conformal classifier for the target label, where `target` indicates can be used to define the index of the target label. Transparent regions indicate that the prediction set does not include the `target` label. - -## Target - -In the binary case, `target` defaults to `2`, indexing the second label: assuming the labels are `[0,1]` then the softmax output for `1` is shown. In the multi-class cases, `target` defaults to the first class: for example, if the labels are `["🐶", "🐱", "🐭"]` (in that order) then the contour indicates the softmax output for `"🐶"`. - -## Set Size - -If `plot_set_size` is set to `true`, then the contour instead visualises the the set size. - -## Univariate and Higher Dimensional Inputs - -For univariate of multiple inputs (>2), this function is not applicable. See [`Plots.areaplot(conf_model::ConformalProbabilisticSet, fitresult, X, y; kwargs...)`](@ref) for an alternative way to visualize prediction for any conformal classifier. - -""" -function Plots.contourf( - conf_model::ConformalProbabilisticSet, - fitresult, - X, - y; - target::Union{Nothing,Real} = nothing, - ntest = 50, - zoom = -1, - xlims = nothing, - ylims = nothing, - plot_set_size = false, - plot_classification_loss = false, - plot_set_loss = false, - temp = 0.1, - κ = 0, - loss_matrix = UniformScaling(1.0), - kwargs..., -) - - # Setup: - X = permutedims(MLJBase.matrix(X)) - - @assert size(X, 1) == 2 "Can only create contour plot for conformal classifier with exactly two input variables." - - x1 = X[1, :] - x2 = X[2, :] - - # Plot limits: - xlims, ylims = generate_lims(x1, x2, xlims, ylims, zoom) - - # Surface range: - x1range = range(xlims[1]; stop = xlims[2], length = ntest) - x2range = range(ylims[1]; stop = ylims[2], length = ntest) - - # Target - if !isnothing(target) - @assert target in levels(y) "Specified target does not match any of the labels." - end - if length(unique(y)) > 1 - if isnothing(target) - @info "No target label supplied, using first." - end - target = isnothing(target) ? 
levels(y)[1] : target - if plot_set_size - _default_title = "Set size" - elseif plot_set_loss - _default_title = "Smooth set loss" - elseif plot_classification_loss - _default_title = "ℒ(C,$(target))" - else - _default_title = "p̂(y=$(target))" - end - else - if plot_set_size - _default_title = "Set size" - elseif plot_set_loss - _default_title = "Smooth set loss" - elseif plot_classification_loss - _default_title = "ℒ(C,$(target-1))" - else - _default_title = "p̂(y=$(target-1))" - end - end - title = !@isdefined(title) ? _default_title : title - - # Predictions - Z = [] - for x2 in x2range, x1 in x1range - p̂ = MLJBase.predict(conf_model, fitresult, table([x1 x2]))[1] - if plot_set_size - z = ismissing(p̂) ? 0 : sum(pdf.(p̂, p̂.decoder.classes) .> 0) - elseif plot_classification_loss - _target = categorical([target]; levels = levels(y)) - z = ConformalPrediction.ConformalTraining.classification_loss( - conf_model, - fitresult, - [x1 x2], - _target; - temp = temp, - loss_matrix = loss_matrix, - ) - elseif plot_set_loss - z = ConformalPrediction.ConformalTraining.smooth_size_loss( - conf_model, - fitresult, - [x1 x2]; - κ = κ, - temp = temp, - ) - else - z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y)) - z = replace(z, 0 => missing) - end - push!(Z, z) - end - Z = reduce(hcat, Z) - Z = Z[findall(levels(y) .== target)[1][1], :] - - # Contour: - if plot_set_size - _n = length(unique(y)) - clim = (0, _n) - plt = contourf( - x1range, - x2range, - Z; - title = title, - xlims = xlims, - ylims = ylims, - c = cgrad(:blues, _n + 1; categorical = true), - clim = clim, - kwargs..., - ) - else - plt = contourf( - x1range, - x2range, - Z; - title = title, - xlims = xlims, - ylims = ylims, - c = cgrad(:blues), - linewidth = 0, - kwargs..., - ) - end - - # Samples: - y = typeof(y) <: CategoricalArrays.CategoricalArray ? y : Int.(y) - return scatter!(plt, x1, x2; group = y, kwargs...) -end - -""" - Plots.areaplot( - conf_model::ConformalProbabilisticSet, fitresult, X, y; - input_var::Union{Nothing,Int,Symbol}=nothing, - kwargs... - ) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of any fitted conformal classifier. Using a stacked area chart, this function plots the softmax output(s) contained the the conformal predictions set on the vertical axis against an input variable `X` on the horizontal axis. In the case of multiple input variables, the `input_var` argument can be used to specify the desired input variable. -""" -function Plots.areaplot( - conf_model::ConformalProbabilisticSet, - fitresult, - X, - y; - input_var::Union{Nothing,Int,Symbol} = nothing, - kwargs..., -) - - # Setup: - Xraw = deepcopy(X) - _names = get_names(Xraw) - X = permutedims(MLJBase.matrix(X)) - - # Dimensions: - if size(X, 1) > 1 - if isnothing(input_var) - @info "Multiple inputs no input variable (`input_var`) specified: defaulting to first variable." - idx = 1 - else - if typeof(input_var) == Int - idx = input_var - else - @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." - idx = findall(_names .== input_var)[1] - end - end - x = X[idx, :] - else - idx = 1 - x = X - end - - # Predictions: - ŷ = MLJBase.predict(conf_model, fitresult, Xraw) - nout = length(levels(y)) - ŷ = - map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y) - ŷ = permutedims(ŷ) - - return areaplot(x, ŷ; kwargs...) -end - -""" - Plots.plot( - conf_model::ConformalInterval, fitresult, X, y; - kwrgs... 
- ) - -A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal regressor. Data (`X`,`y`) are plotted as dots and overlaid with predictions intervals. `y` is plotted on the vertical axis against a single variable `X` on the horizontal axis. A shaded area indicates the prediction interval. The line in the center of the interval is the midpoint of the interval and can be interpreted as the point estimate of the conformal regressor. In case `X` is multi-dimensional, `input_var` can be used to specify the input variable of interest that will be used for the horizontal axis. If unspecified, the first variable will be plotting by default. -""" -function Plots.plot( - conf_model::ConformalInterval, - fitresult, - X, - y; - input_var::Union{Nothing,Int,Symbol} = nothing, - xlims::Union{Nothing,Tuple} = nothing, - ylims::Union{Nothing,Tuple} = nothing, - zoom::Real = -0.5, - train_lab::Union{Nothing,String} = nothing, - test_lab::Union{Nothing,String} = nothing, - ymid_lw::Int = 1, - kwargs..., -) - - # Setup - title = !@isdefined(title) ? "" : title - train_lab = isnothing(train_lab) ? "Observed" : train_lab - test_lab = isnothing(test_lab) ? "Predicted" : test_lab - - Xraw = deepcopy(X) - _names = get_names(Xraw) - X = permutedims(MLJBase.matrix(X)) - - # Dimensions: - if size(X, 1) > 1 - if isnothing(input_var) - @info "Multivariate input for regression with no input variable (`input_var`) specified: defaulting to first variable." - idx = 1 - else - if typeof(input_var) == Int - idx = input_var - else - @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." - idx = findall(_names .== input_var)[1] - end - end - x = X[idx, :] - else - idx = 1 - x = X - end - - # Plot limits: - xlims, ylims = generate_lims(x, y, xlims, ylims, zoom) - - # Plot training data: - plt = scatter( - vec(x), - vec(y); - label = train_lab, - xlim = xlims, - ylim = ylims, - title = title, - kwargs..., - ) - - # Plot predictions: - ŷ = MLJBase.predict(conf_model, fitresult, Xraw) - lb, ub = eachcol(reduce(vcat, map(y -> permutedims(collect(y)), ŷ))) - ymid = (lb .+ ub) ./ 2 - yerror = (ub .- lb) ./ 2 - xplot = vec(x) - _idx = sortperm(xplot) - return plot!( - plt, - xplot[_idx], - ymid[_idx]; - label = test_lab, - ribbon = (yerror, yerror), - lw = ymid_lw, - kwargs..., - ) -end - -""" - Plots.bar(conf_model::ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...) - -A `Plots.jl` recipe/method extension that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. It can be useful to plot the distribution of set sizes in order to visually asses how adaptive a conformal predictor is. For more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’. 
-""" -function Plots.bar( - conf_model::ConformalModel, - fitresult, - X; - label = "", - xtickfontsize = 6, - kwrgs..., -) - ŷ = MLJBase.predict(conf_model, fitresult, X) - idx = ConformalPrediction.size_indicator(ŷ) - x = sort(levels(idx); lt = natural) - y = [sum(idx .== _x) for _x in x] - return Plots.bar(x, y; label = label, xtickfontsize = xtickfontsize, kwrgs...) -end +include("regression.jl") +include("bar.jl") +include("classification.jl") \ No newline at end of file diff --git a/src/ConformalPrediction/bar.jl b/src/ConformalPrediction/bar.jl new file mode 100644 index 0000000..8503187 --- /dev/null +++ b/src/ConformalPrediction/bar.jl @@ -0,0 +1,32 @@ +""" + plot( + conf_model::ConformalModel, + fitresult, + X + ) + +A `Plots.jl` recipe that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. It can be useful to plot the distribution of set sizes in order to visually asses how adaptive a conformal predictor is. For more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’. +""" +@recipe function plot( + conf_model::ConformalModel, + fitresult, + X +) + + # Plot attributes: + xtickfontsize --> 6 + + # Setup: + ŷ = MLJBase.predict(conf_model, fitresult, X) + idx = ConformalPrediction.size_indicator(ŷ) + x = sort(levels(idx); lt=natural) + y = [sum(idx .== _x) for _x in x] + + # Bar chart + @series begin + seriestype := :bar + label --> "" + x, y + end + +end \ No newline at end of file diff --git a/src/ConformalPrediction/classification.jl b/src/ConformalPrediction/classification.jl new file mode 100644 index 0000000..ad966f1 --- /dev/null +++ b/src/ConformalPrediction/classification.jl @@ -0,0 +1,229 @@ +@doc raw""" + plot( + conf_model::ConformalProbabilisticSet, + fitresult, + X, + y; + input_var=nothing, + target=nothing, + ntest=50, + zoom=-1, + plot_set_size=false, + plot_classification_loss=false, + plot_set_loss=false, + temp=0.1, + κ=0, + loss_matrix=UniformScaling(1.0), + ) + +A `Plots.jl` recipe that can be used to visualize the conformal predictions of a fitted conformal classifier. + +## Two Dimensional Inputs + +Data (`X`,`y`) are plotted as dots and overlaid with predictions sets. `y` is used to indicate the ground-truth labels of samples by colour. Samples are visualized in a two-dimensional feature space, so it is expected that `X` ``\in \mathcal{R}^2``. By default, a contour is used to visualize the softmax output of the conformal classifier for the target label, where `target` indicates can be used to define the index of the target label. Transparent regions indicate that the prediction set does not include the `target` label. + +### Target + +In the binary case, `target` defaults to `2`, indexing the second label: assuming the labels are `[0,1]` then the softmax output for `1` is shown. In the multi-class cases, `target` defaults to the first class: for example, if the labels are `["🐶", "🐱", "🐭"]` (in that order) then the contour indicates the softmax output for `"🐶"`. 
+
+### Set Size
+
+If `plot_set_size` is set to `true`, then the contour instead visualises the set size.
+
+## Univariate and Higher Dimensional Inputs
+
+In the case of univariate or higher-dimensional (>2) inputs, a stacked area plot is created: in particular, this method plots the softmax output(s) contained in the conformal prediction set on the vertical axis against an input variable `X` on the horizontal axis. In the case of multiple input variables, the `input_var` argument can be used to specify the desired input variable.
+
+"""
+@recipe function plot(
+    conf_model::ConformalProbabilisticSet,
+    fitresult,
+    X,
+    y;
+    input_var=nothing,
+    target=nothing,
+    ntest=50,
+    zoom=-1,
+    plot_set_size=false,
+    plot_classification_loss=false,
+    plot_set_loss=false,
+    temp=0.1,
+    κ=0,
+    loss_matrix=UniformScaling(1.0),
+)
+
+    # Get user-defined arguments:
+    xlims = get(plotattributes, :xlims, nothing)
+    ylims = get(plotattributes, :ylims, nothing)
+
+    if size(permutedims(MLJBase.matrix(X)), 1) > 2
+
+        # AREA PLOT FOR MULTI-D
+
+        # Plot attributes:
+        xtickfontsize --> 6
+
+        # Setup:
+        Xraw = deepcopy(X)
+        _names = get_names(Xraw)
+        X = permutedims(MLJBase.matrix(X))
+
+        # Dimensions:
+        if size(X, 1) > 1
+            if isnothing(input_var)
+                @info "Multiple inputs but no input variable (`input_var`) specified: defaulting to first variable."
+                idx = 1
+            else
+                if typeof(input_var) == Int
+                    idx = input_var
+                else
+                    @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`."
+                    idx = findall(_names .== input_var)[1]
+                end
+            end
+            x = X[idx, :]
+        else
+            idx = 1
+            x = X
+        end
+
+        # Predictions:
+        ŷ = MLJBase.predict(conf_model, fitresult, Xraw)
+        nout = length(levels(y))
+        ŷ =
+            map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y)
+        ŷ = permutedims(ŷ)
+
+        # Area chart
+        args = (x, ŷ)
+        data = cumsum(args[end], dims=2)
+        x = length(args) == 1 ? (axes(data, 1)) : args[1]
+        seriestype := :line
+        for i in axes(data, 2)
+            @series begin
+                fillrange := i > 1 ? data[:, i-1] : 0
+                x, data[:, i]
+            end
+        end
+
+    else
+
+        # CONTOUR PLOT FOR 2D
+
+        # Setup:
+        x1, x2, x1range, x2range, Z, xlims, ylims, _default_title = setup_contour_cp(
+            conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target,
+            plot_set_size, plot_classification_loss, plot_set_loss, temp, κ, loss_matrix,
+        )
+
+        # Contour:
+        _n = length(unique(y))
+        clim = (0, _n)
+        @series begin
+            seriestype := :contourf
+            x1range, x2range, Z
+        end
+
+        # Scatter plot:
+        for (i, x) in enumerate(unique(sort(y)))
+            @series begin
+                seriestype := :scatter
+                markercolor := i
+                group_idx = findall(y .== x)
+                label --> "$(x)"
+                x1[group_idx], x2[group_idx]
+            end
+        end
+
+    end
+
+end
+
+function setup_contour_cp(
+    conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target,
+    plot_set_size,
+    plot_classification_loss,
+    plot_set_loss,
+    temp,
+    κ,
+    loss_matrix,
+)
+    X = permutedims(MLJBase.matrix(X))
+
+    x1 = X[1, :]
+    x2 = X[2, :]
+
+    # Plot limits:
+    xlims, ylims = generate_lims(x1, x2, xlims, ylims, zoom)
+
+    # Surface range:
+    x1range = range(xlims[1]; stop=xlims[2], length=ntest)
+    x2range = range(ylims[1]; stop=ylims[2], length=ntest)
+
+    # Target
+    if !isnothing(target)
+        @assert target in levels(y) "Specified target does not match any of the labels."
+    end
+    if length(unique(y)) > 1
+        if isnothing(target)
+            @info "No target label supplied, using first."
+        end
+        target = isnothing(target) ? levels(y)[1] : target
+        if plot_set_size
+            _default_title = "Set size"
+        elseif plot_set_loss
+            _default_title = "Smooth set loss"
+        elseif plot_classification_loss
+            _default_title = "ℒ(C,$(target))"
+        else
+            _default_title = "p̂(y=$(target))"
+        end
+    else
+        if plot_set_size
+            _default_title = "Set size"
+        elseif plot_set_loss
+            _default_title = "Smooth set loss"
+        elseif plot_classification_loss
+            _default_title = "ℒ(C,$(target-1))"
+        else
+            _default_title = "p̂(y=$(target-1))"
+        end
+    end
+
+    # Predictions
+    Z = []
+    for x2 in x2range, x1 in x1range
+        p̂ = MLJBase.predict(conf_model, fitresult, table([x1 x2]))[1]
+        if plot_set_size
+            z = ismissing(p̂) ? 0 : sum(pdf.(p̂, p̂.decoder.classes) .> 0)
+        elseif plot_classification_loss
+            _target = categorical([target]; levels=levels(y))
+            z = ConformalPrediction.ConformalTraining.classification_loss(
+                conf_model,
+                fitresult,
+                [x1 x2],
+                _target;
+                temp=temp,
+                loss_matrix=loss_matrix,
+            )
+        elseif plot_set_loss
+            z = ConformalPrediction.ConformalTraining.smooth_size_loss(
+                conf_model,
+                fitresult,
+                [x1 x2];
+                κ=κ,
+                temp=temp,
+            )
+        else
+            z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y))
+            z = replace(z, 0 => missing)
+        end
+        push!(Z, z)
+    end
+    Z = reduce(hcat, Z)
+    Z = Z[findall(levels(y) .== target)[1][1], :]
+
+    return x1, x2, x1range, x2range, Z, xlims, ylims, _default_title
+end
\ No newline at end of file
diff --git a/src/ConformalPrediction/regression.jl b/src/ConformalPrediction/regression.jl
new file mode 100644
index 0000000..4734026
--- /dev/null
+++ b/src/ConformalPrediction/regression.jl
@@ -0,0 +1,91 @@
+"""
+    plot(
+        conf_model::ConformalInterval,
+        fitresult,
+        X,
+        y;
+        input_var=nothing,
+        zoom=-0.5,
+        train_lab=nothing,
+        test_lab=nothing,
+    )
+
+A `Plots.jl` recipe that can be used to visualize the conformal predictions of a fitted conformal regressor. Data (`X`,`y`) are plotted as dots and overlaid with prediction intervals. `y` is plotted on the vertical axis against a single variable `X` on the horizontal axis. A shaded area indicates the prediction interval. The line in the center of the interval is the midpoint of the interval and can be interpreted as the point estimate of the conformal regressor. In case `X` is multi-dimensional, `input_var` can be used to specify the input variable of interest that will be used for the horizontal axis. If unspecified, the first variable will be plotted by default.
+"""
+@recipe function plot(
+    conf_model::ConformalInterval,
+    fitresult,
+    X,
+    y;
+    input_var=nothing,
+    zoom=-0.5,
+    train_lab=nothing,
+    test_lab=nothing,
+)
+
+    # Get user-defined arguments:
+    train_lab = isnothing(train_lab) ? "Observed" : train_lab
+    test_lab = isnothing(test_lab) ? "Predicted" : test_lab
"Predicted" : test_lab + title = get(plotattributes, :xlims, "") + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + + # Plot attributes: + linewidth --> 1 + + # Setup: + x, y, xlims, ylims, Xraw = setup_ci(X, y, input_var, xlims, ylims, zoom) + + # Plot predictions: + ŷ = MLJBase.predict(conf_model, fitresult, Xraw) + lb, ub = eachcol(reduce(vcat, map(y -> permutedims(collect(y)), ŷ))) + ymid = (lb .+ ub) ./ 2 + yerror = (ub .- lb) ./ 2 + xplot = vec(x) + _idx = sortperm(xplot) + @series begin + seriestype := :path + ribbon := (yerror, yerror) + label := test_lab + xplot[_idx], ymid[_idx] + end + + # Scatter observed data: + @series begin + seriestype := :scatter + label := train_lab + vec(x), vec(y) + end + +end + +function setup_ci(X, y, input_var, xlims, ylims, zoom) + + Xraw = deepcopy(X) + _names = get_names(Xraw) + X = permutedims(MLJBase.matrix(X)) + + # Dimensions: + if size(X, 1) > 1 + if isnothing(input_var) + @info "Multivariate input for regression with no input variable (`input_var`) specified: defaulting to first variable." + idx = 1 + else + if typeof(input_var) == Int + idx = input_var + else + @assert input_var ∈ _names "$(input_var) is not among the variable names of `X`." + idx = findall(_names .== input_var)[1] + end + end + x = X[idx, :] + else + idx = 1 + x = X + end + + # Plot limits: + xlims, ylims = generate_lims(x, y, xlims, ylims, zoom) + + return x, y, xlims, ylims, Xraw +end \ No newline at end of file diff --git a/src/CounterfactualExplations/counterfactuals.jl b/src/CounterfactualExplations/counterfactuals.jl index f5a705f..a2d2390 100644 --- a/src/CounterfactualExplations/counterfactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,38 +1,89 @@ -using MLUtils: stack - """ - Plots.plot( + plot( ce::CounterfactualExplanation; - alpha_ = 0.5, - plot_up_to::Union{Nothing,Int} = nothing, - plot_proba::Bool = false, - kwargs..., + target=nothing, + length_out=100, + zoom=-0.1, + dim_red=:pca, + plot_loss=false, + loss_fun=nothing, + plot_up_to = nothing, + n_points = nothing, ) -Calling `plot` on an instance of type `CounterfactualExplanation` returns a plot that visualises the entire counterfactual path. For multi-dimensional input data, the data is first compressed into two dimensions. The decision boundary is then approximated using using a Nearest Neighbour classifier. This is still somewhat experimental at the moment. +Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the training data (scatters), model predictions for the specified `target` (contour) and the counterfactual path (scatter). +""" +@recipe function plot( + ce::CounterfactualExplanation; + target=nothing, + length_out=100, + zoom=-0.1, + dim_red=:pca, + plot_loss=false, + loss_fun=nothing, + plot_up_to = nothing, + n_points = nothing, +) + if !isnothing(n_points) + if n_points < size(ce.data.X, 2) + @info "Undersampling to $(n_points) points." + else + @info "Oversampling to $(n_points) points." + end + xlims, ylims = extrema(ce.data.X[1, :]), extrema(ce.data.X[2, :]) + ce = deepcopy(ce) + ce.data = DataPreprocessing.subsample(ce.data, n_points) + else + xlims, ylims = nothing, nothing + end -# Examples + # Asserts + @assert !plot_loss || !isnothing(loss_fun) "Need to provide a loss function to plot the loss, e.g. (`loss_fun=Flux.Losses.logitcrossentropy`)." 
+ + # Get user-defined arguments: + xlims = get(plotattributes, :xlims, xlims) + ylims = get(plotattributes, :ylims, ylims) + ms = get(plotattributes, :markersize, 3) + mspath = ms*2 + msfinal = mspath*2 + + # Plot attributes + linewidth --> 0.1 + + contour_series, X, y, xlims, ylims = setup_model_plot( + ce.M, + ce.data, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, + ) -```julia-repl -# Search: -generator = GenericGenerator() -ce = generate_counterfactual(x, target, counterfactual_data, M, generator) + xlims --> xlims + ylims --> ylims -plot(ce) -``` -""" -function Plots.plot( - ce_plot::CounterfactualExplanation; - alpha_ = 0.5, - plot_up_to::Union{Nothing,Int} = nothing, - plot_proba::Bool = false, - n_points = 1000, - kwargs..., -) + # Contour plot: + @series begin + seriestype := :contourf + contour_series[1], contour_series[2], contour_series[3] + end - ce = deepcopy(ce_plot) - ce.data = DataPreprocessing.subsample(ce.data, n_points) + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + markersize := ms + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end + end max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) @@ -41,20 +92,26 @@ function Plots.plot( minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) - - for t = 1:max_iter - final_state = t == max_iter - plot_state(ce, t, final_state; ingredients...) - end - - plt = if plot_proba - Plots.plot(ingredients.p1, ingredients.p2; kwargs...) - else - Plots.plot(ingredients.p1; kwargs...) + path_x, path_y = setup_ce_plot(ce) + + # Outer loop over number of counterfactuals: + for (num_counterfactual, X) in enumerate(eachslice(path_x, dims=3)) + # Inner loop over counterfactual search steps: + steps = zip(eachcol(X), path_y) + for (i,(x,y)) in enumerate(steps) + i <= max_iter || break + _final_iter = i == length(steps) || i == max_iter + _annotate = i == length(steps) && ce.num_counterfactuals > 1 + @series begin + seriestype := :scatter + markercolor := CategoricalArrays.levelcode.(y[num_counterfactual]) + markersize := _final_iter ? msfinal : mspath + series_annotation := _annotate ? text("C$(num_counterfactual)", mspath) : nothing + label := :none + x[1,:], x[2,:] + end + end end - - return plt end """ @@ -75,11 +132,11 @@ animate_path(ce) function animate_path( ce::CounterfactualExplanation, path = tempdir(); - alpha_ = 0.5, plot_up_to::Union{Nothing,Int} = nothing, - plot_proba::Bool = false, - kwargs..., + legend = :topright, + kwrgs..., ) + max_iter = total_steps(ce) max_iter = if isnothing(plot_up_to) total_steps(ce) @@ -87,101 +144,22 @@ function animate_path( minimum([plot_up_to, max_iter]) end max_iter += 1 - ingredients = set_up_plots(ce; alpha = alpha_, plot_proba = plot_proba, kwargs...) anim = @animate for t = 1:max_iter - final_state = t == max_iter - plot_state(ce, t, final_state; ingredients...) - if plot_proba - plot(ingredients.p1, ingredients.p2; kwargs...) - else - plot(ingredients.p1; kwargs...) - end + plot(ce; plot_up_to=t, legend=legend, kwrgs...) end return anim end """ - plot_state( - ce::CounterfactualExplanation, - t::Int, - final_state::Bool; - kwargs... - ) - -Helper function that plots a single step of the counterfactual path. -""" -function plot_state(ce::CounterfactualExplanation, t::Int, final_state::Bool; kwargs...) - args = PlotIngredients(; kwargs...) 
-    x1 = args.path_embedded[1, t, :]
-    x2 = args.path_embedded[2, t, :]
-    y = args.path_labels[t]
-    _c = CategoricalArrays.levelcode.(y)
-    n_ = ce.num_counterfactuals
-    label_ = reshape(["C$i" for i = 1:n_], 1, n_)
-    if !final_state
-        scatter!(args.p1, x1, x2; group = y, colour = _c, ms = 5, label = "")
-    else
-        scatter!(args.p1, x1, x2; group = y, colour = _c, ms = 10, label = "")
-        if n_ > 1
-            label_1 = vec([text(lab, 5) for lab in label_])
-            annotate!(x1, x2, label_1)
-        end
-    end
-    if args.plot_proba
-        probs_ = reshape(reduce(vcat, args.path_probs[1:t]), t, n_)
-        if t == 1 && n_ > 1
-            label_2 = label_
-        else
-            label_2 = ""
-        end
-        plot!(
-            args.p2,
-            probs_;
-            label = label_2,
-            color = reshape(1:n_, 1, n_),
-            title = "p(y=$(ce.target))",
-        )
-    end
-end
-
-"A container used for plotting."
-Base.@kwdef struct PlotIngredients
-    p1::Any
-    p2::Any
-    path_embedded::Any
-    path_labels::Any
-    path_probs::Any
-    alpha::Any
-    plot_proba::Any
-end
-
 """
-    set_up_plots(
-        ce::CounterfactualExplanation;
-        alpha,
-        plot_proba,
-        kwargs...
-    )
+    setup_ce_plot(ce::CounterfactualExplanation)

 A helper method that prepares data for plotting.
 """
-function set_up_plots(ce::CounterfactualExplanation; alpha, plot_proba, kwargs...)
-    p1 = plot(ce.M, ce.data; target = ce.target, alpha = alpha, kwargs...)
-    p2 = plot(; xlims = (1, total_steps(ce) + 1), ylims = (0, 1))
+function setup_ce_plot(ce::CounterfactualExplanation)
     path_embedded = embed_path(ce)
     path_labels = CounterfactualExplanations.counterfactual_label_path(ce)
     y_levels = ce.data.y_levels
-    path_labels = map(x -> CategoricalArrays.categorical(x; levels = y_levels), path_labels)
-    path_probs = CounterfactualExplanations.target_probs_path(ce)
-    output = (
-        p1 = p1,
-        p2 = p2,
-        path_embedded = path_embedded,
-        path_labels = path_labels,
-        path_probs = path_probs,
-        alpha = alpha,
-        plot_proba = plot_proba,
-    )
-    return output
+    path_labels = map(x -> CategoricalArrays.categorical(x; levels=y_levels), path_labels)
+    return path_embedded, path_labels
 end
diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl
index fabc294..c9b030d 100644
--- a/src/CounterfactualExplations/data.jl
+++ b/src/CounterfactualExplations/data.jl
@@ -12,9 +12,9 @@ function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Sy
     else
         @info "Training model to compress data."
         if dim_red == :pca
-            tfn = MultivariateStats.fit(PCA, X_train; maxoutdim=2)
+            tfn = MultivariateStats.fit(PCA, X_train; maxoutdim = 2)
         elseif dim_red == :tsne
-            tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim=2)
+            tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim = 2)
         end
         data.input_encoder = nothing
         X = isnothing(X) ? X_train : X
@@ -22,7 +22,7 @@ function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Sy
     end

     # Transforming:
-    X = typeof(X) <: Vector{<:Matrix} ? MLUtils.stack(X, dims=2) : X
+    X = typeof(X) <: Vector{<:Matrix} ? MLUtils.stack(X, dims = 2) : X
     if !isnothing(tfn) && !isnothing(X)
         X = mapslices(x -> MultivariateStats.predict(tfn, x), X, dims = 1)
     else
@@ -53,8 +53,24 @@ function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol = :pca)
     return X', y, multi_dim
 end

-function Plots.scatter!(data::CounterfactualData; dim_red::Symbol = :pca, kwargs...)
+"""
+    plot(data::CounterfactualData; dim_red = :pca)
+
+Calling `Plots.plot` on a `CounterfactualData` object will generate a scatter plot of the data.
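+
+For example (a sketch, assuming `counterfactual_data` is an existing `CounterfactualData` object):
+
+```julia
+plot(counterfactual_data)                  # two-dimensional data is plotted as is
+plot(counterfactual_data; dim_red = :tsne) # higher-dimensional data is first compressed, here via t-SNE
+```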
+""" +@recipe function plot(data::CounterfactualData; dim_red = :pca) + + # Set up: X, y, _ = prepare_for_plotting(data; dim_red = dim_red) - _c = Int.(y.refs) - return Plots.scatter!(X[:, 1], X[:, 2]; group = y, colour = _c, kwargs...) + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end + end end diff --git a/src/CounterfactualExplations/models.jl b/src/CounterfactualExplations/models.jl index 19fbf16..707de47 100644 --- a/src/CounterfactualExplations/models.jl +++ b/src/CounterfactualExplations/models.jl @@ -2,28 +2,92 @@ using DataAPI using Distributions: pdf using NearestNeighborModels: KNNClassifier -function Plots.plot( +""" + function plot( + M::AbstractFittedModel, + data::CounterfactualData; + target = nothing, + length_out = 100, + zoom = -0.1, + dim_red = :pca, + plot_loss = false, + loss_fun = nothing, + ) + +Calling `Plots.plot` on a `AbstractFittedModel` will plot the model's predictions as a contour. The `target` argument can be used to plot the predictions for a specific target variable. The `length_out` argument can be used to control the number of points used to plot the contour. The `zoom` argument can be used to control the zoom of the plot. The `dim_red` argument can be used to control the method used to reduce the dimensionality of the data if it has more than two features. +""" +@recipe function plot( M::AbstractFittedModel, - data::DataPreprocessing.CounterfactualData; - target::Union{Nothing,RawTargetType} = nothing, - colorbar = true, - title = "", + data::CounterfactualData; + target = nothing, length_out = 100, zoom = -0.1, - xlims = nothing, - ylims = nothing, - linewidth = 0.1, - alpha = 1.0, - contour_alpha = 1.0, - dim_red::Symbol = :pca, - plot_loss::Bool = false, - loss_fun::Union{Nothing,Function} = nothing, - kwargs..., + dim_red = :pca, + plot_loss = false, + loss_fun = nothing, +) + + # Asserts + @assert !plot_loss || !isnothing(loss_fun) "Need to provide a loss function to plot the loss, e.g. (`loss_fun=Flux.Losses.logitcrossentropy`)." 
+ + # Get user-defined arguments: + xlims = get(plotattributes, :xlims, nothing) + ylims = get(plotattributes, :ylims, nothing) + + # Plot attributes + linewidth --> 0.1 + + contour_series, X, y, xlims, ylims = setup_model_plot( + M, + data, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, + ) + + xlims --> xlims + ylims --> ylims + + # Contour plot: + @series begin + seriestype := :contourf + contour_series[1], contour_series[2], contour_series[3] + end + + # Scatter plot: + for (i, x) in enumerate(unique(sort(y))) + @series begin + seriestype := :scatter + markercolor := i + group_idx = findall(y .== x) + label --> "$(x)" + X[group_idx, 1], X[group_idx, 2] + end + end + +end + +function setup_model_plot( + M::AbstractFittedModel, + data::CounterfactualData, + target, + length_out, + zoom, + dim_red, + plot_loss, + loss_fun, + xlims, + ylims, ) X, _ = DataPreprocessing.unpack_data(data) ŷ = probs(M, X) # true predictions if size(ŷ, 1) > 1 - ŷ = vec(Flux.onecold(ŷ, 1:size(ŷ, 1))) + ŷ = vec(OneHotArrays.onecold(ŷ, 1:size(ŷ, 1))) else ŷ = vec(ŷ) end @@ -44,6 +108,7 @@ function Plots.plot( else xlims = xlims .+ (zoom, -zoom) end + if isnothing(ylims) ylims = (minimum(X[:, 2]), maximum(X[:, 2])) .+ (zoom, -zoom) else @@ -54,9 +119,11 @@ function Plots.plot( plot_loss = plot_loss || !isnothing(loss_fun) - if plot_loss + if plot_loss # Loss surface: - Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range] + Z = [ + loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range + ] else # Prediction surface: if multi_dim @@ -84,22 +151,10 @@ function Plots.plot( target_idx = get_target_index(data.y_levels, target) z = plot_loss ? Z[1, :] : Z[target_idx, :] - # Contour: - Plots.contourf( - x_range, - y_range, - z; - colorbar = colorbar, - title = title, - linewidth = linewidth, - xlims = xlims, - ylims = ylims, - kwargs..., - alpha = contour_alpha, - ) + # Collect: + contour_series = (x_range, y_range, z) - # Samples: - return Plots.scatter!(data; dim_red = dim_red, alpha = alpha, kwargs...) + return contour_series, X, y, xlims, ylims end function voronoi(X::AbstractMatrix, y::AbstractVector) diff --git a/src/LaplaceRedux/LaplaceRedux.jl b/src/LaplaceRedux/LaplaceRedux.jl index b8fa737..5120669 100644 --- a/src/LaplaceRedux/LaplaceRedux.jl +++ b/src/LaplaceRedux/LaplaceRedux.jl @@ -1,142 +1,184 @@ using LaplaceRedux using Trapz -function Plots.plot( +""" + plot( + la::Laplace, + X::AbstractArray, + y::AbstractArray; + link_approx=:probit, + target=nothing, + length_out=50, + zoom=-1, + ) + +Calling `Plots.plot` on a `Laplace` object will plot the posterior predictive distribution and the training data. +""" +@recipe function plot( la::Laplace, X::AbstractArray, y::AbstractArray; - link_approx::Symbol = :probit, - target::Union{Nothing,Real} = nothing, - colorbar = true, - title = nothing, - length_out = 50, - zoom = -1, - xlims = nothing, - ylims = nothing, - linewidth = 0.1, - lw = 4, - kwargs..., + link_approx=:probit, + target=nothing, + length_out=50, + zoom=-1, ) + + # Asserts: if la.likelihood == :regression @assert size(X, 1) == 1 "Cannot plot regression for multiple input variables." else @assert size(X, 1) == 2 "Cannot plot classification for more than two input variables." 
     end

+    # Get user-defined arguments:
+    xlims = get(plotattributes, :xlims, nothing)
+    ylims = get(plotattributes, :ylims, nothing)
+    title = get(plotattributes, :title, nothing)
+
+    # Plot attributes
+    lw = get(plotattributes, :linewidth, 1)
+    lw_yhat = lw*2
+    lw_contour = lw*0.1
+
     if la.likelihood == :regression
-        # REGRESSION
+        xrange, yrange, xlims, ylims = surface_range(X, y, xlims, ylims, zoom, length_out)
+        xlims := xlims
+        ylims := ylims

-        # Surface range:
-        if isnothing(xlims)
-            xlims = (minimum(X), maximum(X)) .+ (zoom, -zoom)
-        else
-            xlims = xlims .+ (zoom, -zoom)
-        end
-        if isnothing(ylims)
-            ylims = (minimum(y), maximum(y)) .+ (zoom, -zoom)
-        else
-            ylims = ylims .+ (zoom, -zoom)
-        end
-        x_range = range(xlims[1]; stop = xlims[2], length = length_out)
-        y_range = range(ylims[1]; stop = ylims[2], length = length_out)
-
-        title = isnothing(title) ? "" : title
-
-        # Plot:
-        scatter(
-            vec(X),
-            vec(y);
-            label = "ytrain",
-            xlim = xlims,
-            ylim = ylims,
-            lw = lw,
-            title = title,
-            kwargs...,
-        )
-        _x = collect(x_range)[:, :]'
-        normal_distr, fμ, fvar = LaplaceRedux.glm_predictive_distribution(la, _x)
+        # Plot predictions:
+        _x = collect(xrange)[:, :]'
+        fμ, fvar = LaplaceRedux.predict(la, _x)
         fμ = vec(fμ)
         fσ = vec(sqrt.(fvar))
-        pred_std = sqrt.(fσ .^ 2 .+ la.prior.σ^2)
-        plot!(
-            x_range,
-            fμ;
-            color = 2,
-            label = "yhat",
-            ribbon = (1.96 * pred_std, 1.96 * pred_std),
-            lw = lw,
-            kwargs...,
-        ) # the specific values 1.96 are used here to create a 95% confidence interval
-    else
-
-        # CLASSIFICATION
-
-        # Surface range:
-        if isnothing(xlims)
-            xlims = (minimum(X[1, :]), maximum(X[1, :])) .+ (zoom, -zoom)
-        else
-            xlims = xlims .+ (zoom, -zoom)
+        @series begin
+            seriestype := :path
+            # 1.96 standard deviations either side of the mean gives a 95% confidence interval:
+            ribbon := (1.96 * fσ, 1.96 * fσ)
+            linewidth := lw_yhat
+            label --> "yhat"
+            xrange, fμ
         end
-        if isnothing(ylims)
-            ylims = (minimum(X[2, :]), maximum(X[2, :])) .+ (zoom, -zoom)
-        else
-            ylims = ylims .+ (zoom, -zoom)
+
+        # Scatter training data:
+        @series begin
+            seriestype := :scatter
+            label --> "ytrain"
+            vec(X), vec(y)
         end
-        x_range = range(xlims[1]; stop = xlims[2], length = length_out)
-        y_range = range(ylims[1]; stop = ylims[2], length = length_out)
-
-        # Plot
-        predict_ = function (X::AbstractVector)
-            z = LaplaceRedux.predict(la,X; link_approx = link_approx)
-            if LaplaceRedux.outdim(la) == 1 # binary
-                z = [1.0 - z[1], z[1]]
-            end
-            return z
+    end
+
+    if la.likelihood == :classification
+
+        xrange, yrange, xlims, ylims = surface_range(X, xlims, ylims, zoom, length_out)
+        xlims := xlims
+        ylims := ylims
+
+        Z, target, title = get_contour(la, xrange, yrange, link_approx, target, title)
+
+        # Contour plot:
+        @series begin
+            seriestype := :contourf
+            linewidth := lw_contour
+            title --> title
+            xrange, yrange, Z[Int(target), :]
         end
-        Z = [predict_([x, y]) for x in x_range, y in y_range]
-        Z = reduce(hcat, Z)
-        if LaplaceRedux.outdim(la) > 1
-            if isnothing(target)
-                @info "No target label supplied, using first."
+
+        # Scatter plot:
+        for (i, x) in enumerate(unique(sort(y)))
+            @series begin
+                seriestype := :scatter
+                markercolor := i
+                group_idx = findall(y .== x)
+                label --> "$(x)"
+                X[1, group_idx], X[2, group_idx]
             end
-            target = isnothing(target) ? 1 : target
-            title = isnothing(title) ? "p̂(y=$(target))" : title
-        else
-            target = isnothing(target) ? 2 : target
-            title = isnothing(title) ? "p̂(y=$(target-1))" : title
"p̂(y=$(target-1))" : title end - # Contour: - contourf( - x_range, - y_range, - Z[Int(target), :]; - colorbar = colorbar, - title = title, - linewidth = linewidth, - xlims = xlims, - ylims = ylims, - kwargs..., - ) - # Samples: - scatter!(X[1, :], X[2, :]; group = Int.(y), color = Int.(y), kwargs...) end + +end + +function surface_range( + X::AbstractArray, y::AbstractArray, + xlims,ylims,zoom,length_out, +) + + # Surface range: + if isnothing(xlims) + xlims = (minimum(X), maximum(X)) .+ (zoom, -zoom) + else + xlims = xlims .+ (zoom, -zoom) + end + if isnothing(ylims) + ylims = (minimum(y), maximum(y)) .+ (zoom, -zoom) + else + ylims = ylims .+ (zoom, -zoom) + end + x_range = range(xlims[1]; stop = xlims[2], length = length_out) + y_range = range(ylims[1]; stop = ylims[2], length = length_out) + return x_range, y_range, xlims, ylims + +end + +function surface_range(X::AbstractArray,xlims,ylims,zoom,length_out) + + if isnothing(xlims) + xlims = (minimum(X[1, :]), maximum(X[1, :])) .+ (zoom, -zoom) + else + xlims = xlims .+ (zoom, -zoom) + end + if isnothing(ylims) + ylims = (minimum(X[2, :]), maximum(X[2, :])) .+ (zoom, -zoom) + else + ylims = ylims .+ (zoom, -zoom) + end + x_range = range(xlims[1]; stop = xlims[2], length = length_out) + y_range = range(ylims[1]; stop = ylims[2], length = length_out) + + return x_range, y_range, xlims, ylims +end + +function get_contour(la::Laplace, x_range, y_range, link_approx, target, title) + + predict_ = function (la, X::AbstractVector) + z = LaplaceRedux.predict(la, X; link_approx = link_approx) + if LaplaceRedux.outdim(la) == 1 # binary + z = [1.0 - z[1], z[1]] + end + return z + end + Z = [predict_(la, [x, y]) for x in x_range, y in y_range] + Z = reduce(hcat, Z) + if LaplaceRedux.outdim(la) > 1 + if isnothing(target) + @info "No target label supplied, using first." + end + target = isnothing(target) ? 1 : target + title = isnothing(title) ? "p̂(y=$(target))" : title + else + target = isnothing(target) ? 2 : target + title = isnothing(title) ? "p̂(y=$(target-1))" : title + end + + return Z, target, title end """ -'Calibration_Plot_Regression(y_cal, samp_distr, n_bins)' + calibration_plot(y_cal, samp_distr, n_bins) This plot displays the true frequency of points in each confidence interval relative to the predicted fraction of points in that interval. The intervals are taken in step of 0.05 quantiles. -Input: --'la::Laplace': the laplace model to use. --'Y_cal': a vector of true values y_t. --'samp_distr': an array of sampled distributions F(x_t) corresponding to the y_t stacked column-wise. --'n_bins': numbers of bins to use. +## Inputs + +- `la::Laplace` -- the laplace model to use. +- `Y_cal` -- a vector of true values y_t. +- `samp_distr` -- an array of sampled distributions F(x_t) corresponding to the y_t stacked column-wise. +- `n_bins` -- numbers of bins to use. 
""" -function Calibration_Plot(la::Laplace, y_cal, samp_distr; n_bins = 20) +function calibration_plot(la::Laplace, y_cal, samp_distr; n_bins = 20) quantiles = collect(range(0; stop = 1, length = n_bins + 1)) # Create a new plot object p = plot() diff --git a/src/TaijaPlotting.jl b/src/TaijaPlotting.jl index e37a519..722571a 100644 --- a/src/TaijaPlotting.jl +++ b/src/TaijaPlotting.jl @@ -2,16 +2,17 @@ module TaijaPlotting using CategoricalArrays using CounterfactualExplanations -using Flux using MLJBase using MultivariateStats +using OneHotArrays using Plots +using RecipesBase export animate_path include("ConformalPrediction/ConformalPrediction.jl") include("CounterfactualExplations/CounterfactualExplanations.jl") include("LaplaceRedux/LaplaceRedux.jl") -export Calibration_Plot +export calibration_plot end diff --git a/test/ConformalPrediction.jl b/test/ConformalPrediction.jl index c0e1e1f..c80a947 100644 --- a/test/ConformalPrediction.jl +++ b/test/ConformalPrediction.jl @@ -21,9 +21,9 @@ isplot(plt) = typeof(plt) <: Plots.Plot fit!(mach, rows = train) @test isplot(bar(mach.model, mach.fitresult, X)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y; input_var = 1)) - @test isplot(areaplot(mach.model, mach.fitresult, X, y; input_var = :x1)) + @test isplot(plot(mach.model, mach.fitresult, X, y)) + @test isplot(plot(mach.model, mach.fitresult, X, y; input_var = 1)) + @test isplot(plot(mach.model, mach.fitresult, X, y; input_var = :x1)) @test isplot(contourf(mach.model, mach.fitresult, X, y)) @test isplot( contourf(mach.model, mach.fitresult, X, y; zoom = -1, plot_set_size = true), diff --git a/test/Project.toml b/test/Project.toml index b197002..b86b174 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -7,10 +8,12 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" TaijaData = "9d524318-b4e6-4a65-86d2-b2b72d07866c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Aqua = "0.8" ConformalPrediction = "0.1" CounterfactualExplanations = "1.1.5" Flux = "0.12, 0.13, 0.14" diff --git a/test/aqua.jl b/test/aqua.jl new file mode 100644 index 0000000..8333b33 --- /dev/null +++ b/test/aqua.jl @@ -0,0 +1,9 @@ +using Aqua +using RecipesBase + +@testset "Aqua.jl" begin + # Ambiguities needs to be tested seperately until the bug in Aqua package (https://github.com/JuliaTesting/Aqua.jl/issues/77) is fixed + Aqua.test_ambiguities([TaijaPlotting]; recursive=false, broken=false) + + Aqua.test_all(TaijaPlotting; ambiguities=false, piracies=false) +end diff --git a/test/runtests.jl b/test/runtests.jl index 98b0638..e5be431 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using TaijaPlotting using Test @testset "TaijaPlotting.jl" begin + include("aqua.jl") include("ConformalPrediction.jl") include("CounterfactualExplanations.jl") include("LaplaceRedux.jl")