Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

32 move to plot recipes instead of overloading #33

Merged
merged 13 commits into from
Sep 5, 2024
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
style = "blue"
13 changes: 9 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,45 +1,50 @@
name = "TaijaPlotting"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
authors = ["Patrick Altmeyer"]
version = "1.2.0"
version = "1.3.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Trapz = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"

[compat]
Aqua = "0.8"
CategoricalArrays = "0.10"
ConformalPrediction = "0.1, 1"
CounterfactualExplanations = "1.1.5"
DataAPI = "1"
Distributions = "0.25"
Flux = "0.12, 0.13, 0.14"
LaplaceRedux = "0.1, 0.2, 1.0.1"
LinearAlgebra = "1.10"
MLJBase = "0.21, 0.22, 1"
MLUtils = "0.4"
MultivariateStats = "0.9, 0.10"
NaturalSort = "1"
NearestNeighborModels = "0.2"
OneHotArrays = "0.2.5"
Plots = "1"
RecipesBase = "1.3.4"
Test = "1.10"
Trapz = "2.0.3"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Aqua", "Test"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
[![Build Status](https://github.com/JuliaTrustworthyAI/TaijaPlotting.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaTrustworthyAI/TaijaPlotting.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/JuliaTrustworthyAI/TaijaPlotting.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaTrustworthyAI/TaijaPlotting.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

A package for plotting custom symbols from Taija packages.
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TaijaPlotting = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
307 changes: 3 additions & 304 deletions src/ConformalPrediction/ConformalPrediction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,307 +38,6 @@ function get_names(X)
return _names
end

@doc raw"""
Plots.contourf(conf_model::ConformalModel,fitresult,X,y;kwargs...)

A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal classifier with exactly two input variable. Data (`X`,`y`) are plotted as dots and overlaid with predictions sets. `y` is used to indicate the ground-truth labels of samples by colour. Samples are visualized in a two-dimensional feature space, so it is expected that `X` ``\in \mathcal{R}^2``. By default, a contour is used to visualize the softmax output of the conformal classifier for the target label, where `target` indicates can be used to define the index of the target label. Transparent regions indicate that the prediction set does not include the `target` label.

## Target

In the binary case, `target` defaults to `2`, indexing the second label: assuming the labels are `[0,1]` then the softmax output for `1` is shown. In the multi-class cases, `target` defaults to the first class: for example, if the labels are `["🐶", "🐱", "🐭"]` (in that order) then the contour indicates the softmax output for `"🐶"`.

## Set Size

If `plot_set_size` is set to `true`, then the contour instead visualises the the set size.

## Univariate and Higher Dimensional Inputs

For univariate of multiple inputs (>2), this function is not applicable. See [`Plots.areaplot(conf_model::ConformalProbabilisticSet, fitresult, X, y; kwargs...)`](@ref) for an alternative way to visualize prediction for any conformal classifier.

"""
function Plots.contourf(
conf_model::ConformalProbabilisticSet,
fitresult,
X,
y;
target::Union{Nothing,Real} = nothing,
ntest = 50,
zoom = -1,
xlims = nothing,
ylims = nothing,
plot_set_size = false,
plot_classification_loss = false,
plot_set_loss = false,
temp = 0.1,
κ = 0,
loss_matrix = UniformScaling(1.0),
kwargs...,
)

# Setup:
X = permutedims(MLJBase.matrix(X))

@assert size(X, 1) == 2 "Can only create contour plot for conformal classifier with exactly two input variables."

x1 = X[1, :]
x2 = X[2, :]

# Plot limits:
xlims, ylims = generate_lims(x1, x2, xlims, ylims, zoom)

# Surface range:
x1range = range(xlims[1]; stop = xlims[2], length = ntest)
x2range = range(ylims[1]; stop = ylims[2], length = ntest)

# Target
if !isnothing(target)
@assert target in levels(y) "Specified target does not match any of the labels."
end
if length(unique(y)) > 1
if isnothing(target)
@info "No target label supplied, using first."
end
target = isnothing(target) ? levels(y)[1] : target
if plot_set_size
_default_title = "Set size"
elseif plot_set_loss
_default_title = "Smooth set loss"
elseif plot_classification_loss
_default_title = "ℒ(C,$(target))"
else
_default_title = "p̂(y=$(target))"
end
else
if plot_set_size
_default_title = "Set size"
elseif plot_set_loss
_default_title = "Smooth set loss"
elseif plot_classification_loss
_default_title = "ℒ(C,$(target-1))"
else
_default_title = "p̂(y=$(target-1))"
end
end
title = !@isdefined(title) ? _default_title : title

# Predictions
Z = []
for x2 in x2range, x1 in x1range
p̂ = MLJBase.predict(conf_model, fitresult, table([x1 x2]))[1]
if plot_set_size
z = ismissing(p̂) ? 0 : sum(pdf.(p̂, p̂.decoder.classes) .> 0)
elseif plot_classification_loss
_target = categorical([target]; levels = levels(y))
z = ConformalPrediction.ConformalTraining.classification_loss(
conf_model,
fitresult,
[x1 x2],
_target;
temp = temp,
loss_matrix = loss_matrix,
)
elseif plot_set_loss
z = ConformalPrediction.ConformalTraining.smooth_size_loss(
conf_model,
fitresult,
[x1 x2];
κ = κ,
temp = temp,
)
else
z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y))
z = replace(z, 0 => missing)
end
push!(Z, z)
end
Z = reduce(hcat, Z)
Z = Z[findall(levels(y) .== target)[1][1], :]

# Contour:
if plot_set_size
_n = length(unique(y))
clim = (0, _n)
plt = contourf(
x1range,
x2range,
Z;
title = title,
xlims = xlims,
ylims = ylims,
c = cgrad(:blues, _n + 1; categorical = true),
clim = clim,
kwargs...,
)
else
plt = contourf(
x1range,
x2range,
Z;
title = title,
xlims = xlims,
ylims = ylims,
c = cgrad(:blues),
linewidth = 0,
kwargs...,
)
end

# Samples:
y = typeof(y) <: CategoricalArrays.CategoricalArray ? y : Int.(y)
return scatter!(plt, x1, x2; group = y, kwargs...)
end

"""
Plots.areaplot(
conf_model::ConformalProbabilisticSet, fitresult, X, y;
input_var::Union{Nothing,Int,Symbol}=nothing,
kwargs...
)

A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of any fitted conformal classifier. Using a stacked area chart, this function plots the softmax output(s) contained the the conformal predictions set on the vertical axis against an input variable `X` on the horizontal axis. In the case of multiple input variables, the `input_var` argument can be used to specify the desired input variable.
"""
function Plots.areaplot(
conf_model::ConformalProbabilisticSet,
fitresult,
X,
y;
input_var::Union{Nothing,Int,Symbol} = nothing,
kwargs...,
)

# Setup:
Xraw = deepcopy(X)
_names = get_names(Xraw)
X = permutedims(MLJBase.matrix(X))

# Dimensions:
if size(X, 1) > 1
if isnothing(input_var)
@info "Multiple inputs no input variable (`input_var`) specified: defaulting to first variable."
idx = 1
else
if typeof(input_var) == Int
idx = input_var
else
@assert input_var ∈ _names "$(input_var) is not among the variable names of `X`."
idx = findall(_names .== input_var)[1]
end
end
x = X[idx, :]
else
idx = 1
x = X
end

# Predictions:
ŷ = MLJBase.predict(conf_model, fitresult, Xraw)
nout = length(levels(y))
ŷ =
map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y)
ŷ = permutedims(ŷ)

return areaplot(x, ŷ; kwargs...)
end

"""
Plots.plot(
conf_model::ConformalInterval, fitresult, X, y;
kwrgs...
)

A `Plots.jl` recipe/method extension that can be used to visualize the conformal predictions of a fitted conformal regressor. Data (`X`,`y`) are plotted as dots and overlaid with predictions intervals. `y` is plotted on the vertical axis against a single variable `X` on the horizontal axis. A shaded area indicates the prediction interval. The line in the center of the interval is the midpoint of the interval and can be interpreted as the point estimate of the conformal regressor. In case `X` is multi-dimensional, `input_var` can be used to specify the input variable of interest that will be used for the horizontal axis. If unspecified, the first variable will be plotting by default.
"""
function Plots.plot(
conf_model::ConformalInterval,
fitresult,
X,
y;
input_var::Union{Nothing,Int,Symbol} = nothing,
xlims::Union{Nothing,Tuple} = nothing,
ylims::Union{Nothing,Tuple} = nothing,
zoom::Real = -0.5,
train_lab::Union{Nothing,String} = nothing,
test_lab::Union{Nothing,String} = nothing,
ymid_lw::Int = 1,
kwargs...,
)

# Setup
title = !@isdefined(title) ? "" : title
train_lab = isnothing(train_lab) ? "Observed" : train_lab
test_lab = isnothing(test_lab) ? "Predicted" : test_lab

Xraw = deepcopy(X)
_names = get_names(Xraw)
X = permutedims(MLJBase.matrix(X))

# Dimensions:
if size(X, 1) > 1
if isnothing(input_var)
@info "Multivariate input for regression with no input variable (`input_var`) specified: defaulting to first variable."
idx = 1
else
if typeof(input_var) == Int
idx = input_var
else
@assert input_var ∈ _names "$(input_var) is not among the variable names of `X`."
idx = findall(_names .== input_var)[1]
end
end
x = X[idx, :]
else
idx = 1
x = X
end

# Plot limits:
xlims, ylims = generate_lims(x, y, xlims, ylims, zoom)

# Plot training data:
plt = scatter(
vec(x),
vec(y);
label = train_lab,
xlim = xlims,
ylim = ylims,
title = title,
kwargs...,
)

# Plot predictions:
ŷ = MLJBase.predict(conf_model, fitresult, Xraw)
lb, ub = eachcol(reduce(vcat, map(y -> permutedims(collect(y)), ŷ)))
ymid = (lb .+ ub) ./ 2
yerror = (ub .- lb) ./ 2
xplot = vec(x)
_idx = sortperm(xplot)
return plot!(
plt,
xplot[_idx],
ymid[_idx];
label = test_lab,
ribbon = (yerror, yerror),
lw = ymid_lw,
kwargs...,
)
end

"""
Plots.bar(conf_model::ConformalModel, fitresult, X; label="", xtickfontsize=6, kwrgs...)

A `Plots.jl` recipe/method extension that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. It can be useful to plot the distribution of set sizes in order to visually asses how adaptive a conformal predictor is. For more adaptive predictors the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’ the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’.
"""
function Plots.bar(
conf_model::ConformalModel,
fitresult,
X;
label = "",
xtickfontsize = 6,
kwrgs...,
)
ŷ = MLJBase.predict(conf_model, fitresult, X)
idx = ConformalPrediction.size_indicator(ŷ)
x = sort(levels(idx); lt = natural)
y = [sum(idx .== _x) for _x in x]
return Plots.bar(x, y; label = label, xtickfontsize = xtickfontsize, kwrgs...)
end
include("regression.jl")
include("bar.jl")
include("classification.jl")
Loading
Loading