Skip to content

Commit

Permalink
Merge branch '32-move-to-plot-recipes-instead-of-overloading'
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Sep 5, 2024
2 parents ba55aa0 + a37dd26 commit b32b5dc
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 124 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Format Check: runs JuliaFormatter over the repository and fails the job when
# any tracked file is not already formatted.
name: Format Check

on:
  push:
    branches:
      - 'main'
      # Branch filters are glob patterns; a bare 'release-' would only match a
      # branch literally named "release-".
      - 'release-*'
    tags: ['*']
  pull_request:

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: julia-actions/setup-julia@latest
        with:
          version: 1
      - uses: actions/checkout@v4
      - name: Install JuliaFormatter
        run: |
          using Pkg
          Pkg.add("JuliaFormatter")
        shell: julia --color=yes {0}
      - name: Format code
        run: |
          using JuliaFormatter
          format("."; verbose=true)
        shell: julia --color=yes {0}
      # Formatting rewrites files in place, so any remaining diff means the
      # committed sources were not formatted; without this step the job would
      # succeed regardless of the formatter's outcome.
      - name: Fail if formatting changed any file
        run: |
          changed = read(`git diff --name-only`, String)
          if !isempty(changed)
              @error "Some files have not been formatted:"
              print(changed)
              exit(1)
          end
        shell: julia --color=yes {0}
2 changes: 1 addition & 1 deletion src/ConformalPrediction/ConformalPrediction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ end

include("regression.jl")
include("bar.jl")
include("classification.jl")
include("classification.jl")
9 changes: 2 additions & 7 deletions src/ConformalPrediction/bar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
A `Plots.jl` recipe that can be used to visualize the set size distribution of a conformal predictor. In the regression case, prediction interval widths are stratified into discrete bins. It can be useful to plot the distribution of set sizes in order to visually assess how adaptive a conformal predictor is. For more adaptive predictors, the distribution of set sizes is typically spread out more widely, which reflects that “the procedure is effectively distinguishing between easy and hard inputs”. This is desirable: when for a given sample it is difficult to make predictions, this should be reflected in the set size (or interval width in the regression case). Since ‘difficult’ lies on some spectrum that ranges from ‘very easy’ to ‘very difficult’, the set size should vary across the spectrum of ‘empty set’ to ‘all labels included’.
"""
@recipe function plot(
conf_model::ConformalModel,
fitresult,
X
)
@recipe function plot(conf_model::ConformalModel, fitresult, X)

# Plot attributes:
xtickfontsize --> 6
Expand All @@ -28,5 +24,4 @@ A `Plots.jl` recipe that can be used to visualize the set size distribution of a
label --> ""
x, y
end

end
end
55 changes: 33 additions & 22 deletions src/ConformalPrediction/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,21 @@ In the case of univariate inputs or higher dimensional inputs, a stacked area pl
# Predictions:
ŷ = MLJBase.predict(conf_model, fitresult, Xraw)
nout = length(levels(y))
ŷ =
map(_y -> ismissing(_y) ? [0 for i = 1:nout] : pdf.(_y, levels(y)), ŷ) |> _y -> reduce(hcat, _y)
ŷ = (_y -> reduce(hcat, _y))(map(
_y -> ismissing(_y) ? [0 for i in 1:nout] : pdf.(_y, levels(y)), ŷ
))
ŷ = permutedims(ŷ)
println(x)
println(ŷ[sortperm(x), :])

# Area chart
args = (x, ŷ)
data = cumsum(args[end], dims=2)
data = cumsum(args[end]; dims=2)
x = length(args) == 1 ? (axes(data, 1)) : args[1]
seriestype := :line
for i in axes(data, 2)
@series begin
fillrange := i > 1 ? data[:, i-1] : 0
fillrange := i > 1 ? data[:, i - 1] : 0
x, data[:, i]
end
end
Expand All @@ -114,8 +115,21 @@ In the case of univariate inputs or higher dimensional inputs, a stacked area pl

# Setup:
x1, x2, x1range, x2range, Z, xlims, ylims, _default_title = setup_contour_cp(
conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target,
plot_set_size, plot_classification_loss, plot_set_loss, temp, κ, loss_matrix,
conf_model,
fitresult,
X,
y,
xlims,
ylims,
zoom,
ntest,
target,
plot_set_size,
plot_classification_loss,
plot_set_loss,
temp,
κ,
loss_matrix,
)

# Contour:
Expand All @@ -136,13 +150,19 @@ In the case of univariate inputs or higher dimensional inputs, a stacked area pl
x1[group_idx], x2[group_idx]
end
end

end

end

function setup_contour_cp(
conf_model, fitresult, X, y, xlims, ylims, zoom, ntest, target,
conf_model,
fitresult,
X,
y,
xlims,
ylims,
zoom,
ntest,
target,
plot_set_size,
plot_classification_loss,
plot_set_loss,
Expand Down Expand Up @@ -201,23 +221,14 @@ function setup_contour_cp(
elseif plot_classification_loss
_target = categorical([target]; levels=levels(y))
z = ConformalPrediction.ConformalTraining.classification_loss(
conf_model,
fitresult,
[x1 x2],
_target;
temp=temp,
loss_matrix=loss_matrix,
conf_model, fitresult, [x1 x2], _target; temp=temp, loss_matrix=loss_matrix
)
elseif plot_set_loss
z = ConformalPrediction.ConformalTraining.smooth_size_loss(
conf_model,
fitresult,
[x1 x2];
κ=κ,
temp=temp,
conf_model, fitresult, [x1 x2]; κ=κ, temp=temp
)
else
z = ismissing(p̂) ? [missing for i = 1:length(levels(y))] : pdf.(p̂, levels(y))
z = ismissing(p̂) ? [missing for i in 1:length(levels(y))] : pdf.(p̂, levels(y))
z = replace(z, 0 => missing)
end
push!(Z, z)
Expand All @@ -226,4 +237,4 @@ function setup_contour_cp(
Z = Z[findall(levels(y) .== target)[1][1], :]

return x1, x2, x1range, x2range, Z, xlims, ylims, _default_title
end
end
4 changes: 1 addition & 3 deletions src/ConformalPrediction/regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ A `Plots.jl` recipe that can be used to visualize the conformal predictions of a
label := train_lab
vec(x), vec(y)
end

end

function setup_ci(X, y, input_var, xlims, ylims, zoom)

Xraw = deepcopy(X)
_names = get_names(Xraw)
X = permutedims(MLJBase.matrix(X))
Expand Down Expand Up @@ -88,4 +86,4 @@ function setup_ci(X, y, input_var, xlims, ylims, zoom)
xlims, ylims = generate_lims(x, y, xlims, ylims, zoom)

return x, y, xlims, ylims, Xraw
end
end
42 changes: 16 additions & 26 deletions src/CounterfactualExplations/counterfactuals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the train
dim_red=:pca,
plot_loss=false,
loss_fun=nothing,
plot_up_to = nothing,
n_points = nothing,
plot_up_to=nothing,
n_points=nothing,
)

if !isnothing(n_points)
if n_points < size(ce.data.X, 2)
@info "Undersampling to $(n_points) points."
Expand All @@ -35,7 +34,7 @@ Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the train
ce = deepcopy(ce)
ce.data = DataPreprocessing.subsample(ce.data, n_points)
else
xlims, ylims = nothing, nothing
xlims, ylims = nothing, nothing
end

# Asserts
Expand All @@ -45,23 +44,14 @@ Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the train
xlims = get(plotattributes, :xlims, xlims)
ylims = get(plotattributes, :ylims, ylims)
ms = get(plotattributes, :markersize, 3)
mspath = ms*2
msfinal = mspath*2
mspath = ms * 2
msfinal = mspath * 2

# Plot attributes
linewidth --> 0.1

contour_series, X, y, xlims, ylims = setup_model_plot(
ce.M,
ce.data,
target,
length_out,
zoom,
dim_red,
plot_loss,
loss_fun,
xlims,
ylims,
ce.M, ce.data, target, length_out, zoom, dim_red, plot_loss, loss_fun, xlims, ylims
)

xlims --> xlims
Expand Down Expand Up @@ -95,20 +85,21 @@ Calling `Plots.plot` on a `CounterfactualExplanation` object will plot the train
path_x, path_y = setup_ce_plot(ce)

# Outer loop over number of counterfactuals:
for (num_counterfactual, X) in enumerate(eachslice(path_x, dims=3))
for (num_counterfactual, X) in enumerate(eachslice(path_x; dims=3))
# Inner loop over counterfactual search steps:
steps = zip(eachcol(X), path_y)
for (i,(x,y)) in enumerate(steps)
steps = zip(eachcol(X), path_y)
for (i, (x, y)) in enumerate(steps)
i <= max_iter || break
_final_iter = i == length(steps) || i == max_iter
_annotate = i == length(steps) && ce.num_counterfactuals > 1
@series begin
seriestype := :scatter
markercolor := CategoricalArrays.levelcode.(y[num_counterfactual])
markersize := _final_iter ? msfinal : mspath
series_annotation := _annotate ? text("C$(num_counterfactual)", mspath) : nothing
series_annotation :=
_annotate ? text("C$(num_counterfactual)", mspath) : nothing
label := :none
x[1,:], x[2,:]
x[1, :], x[2, :]
end
end
end
Expand All @@ -131,12 +122,11 @@ animate_path(ce)
"""
function animate_path(
ce::CounterfactualExplanation,
path = tempdir();
plot_up_to::Union{Nothing,Int} = nothing,
legend = :topright,
path=tempdir();
plot_up_to::Union{Nothing,Int}=nothing,
legend=:topright,
kwrgs...,
)

max_iter = total_steps(ce)
max_iter = if isnothing(plot_up_to)
total_steps(ce)
Expand All @@ -145,7 +135,7 @@ function animate_path(
end
max_iter += 1

anim = @animate for t = 1:max_iter
anim = @animate for t in 1:max_iter
plot(ce; plot_up_to=t, legend=legend, kwrgs...)
end
return anim
Expand Down
18 changes: 9 additions & 9 deletions src/CounterfactualExplations/data.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using MLUtils

function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Symbol = :pca)
function embed(data::CounterfactualData, X::AbstractArray=nothing; dim_red::Symbol=:pca)

# Training compressor:
if typeof(data.input_encoder) <: MultivariateStats.AbstractDimensionalityReduction
Expand All @@ -12,19 +12,19 @@ function embed(data::CounterfactualData, X::AbstractArray = nothing; dim_red::Sy
else
@info "Training model to compress data."
if dim_red == :pca
tfn = MultivariateStats.fit(PCA, X_train; maxoutdim = 2)
tfn = MultivariateStats.fit(PCA, X_train; maxoutdim=2)
elseif dim_red == :tsne
tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim = 2)
tfn = MultivariateStats.fit(TSNE, X_train; maxoutdim=2)
end
data.input_encoder = nothing
X = isnothing(X) ? X_train : X
end
end

# Transforming:
X = typeof(X) <: Vector{<:Matrix} ? MLUtils.stack(X, dims = 2) : X
X = typeof(X) <: Vector{<:Matrix} ? MLUtils.stack(X; dims=2) : X
if !isnothing(tfn) && !isnothing(X)
X = mapslices(x -> MultivariateStats.predict(tfn, x), X, dims = 1)
X = mapslices(x -> MultivariateStats.predict(tfn, x), X; dims=1)
else
X = isnothing(X) ? X_train : X
end
Expand All @@ -42,13 +42,13 @@ function embed_path(ce::CounterfactualExplanation)
return embed(data_, path(ce))
end

function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol = :pca)
function prepare_for_plotting(data::CounterfactualData; dim_red::Symbol=:pca)
X, _ = DataPreprocessing.unpack_data(data)
y = data.output_encoder.labels
@assert size(X, 1) != 1 "Don't know how to plot 1-dimensional data."
multi_dim = size(X, 1) > 2
if multi_dim
X = embed(data, X; dim_red = dim_red)
X = embed(data, X; dim_red=dim_red)
end
return X', y, multi_dim
end
Expand All @@ -58,10 +58,10 @@ end
Calling `Plots.plot` on a `data::CounterfactualData` object will generate a scatter plot of the data.
"""
@recipe function plot(data::CounterfactualData; dim_red = :pca)
@recipe function plot(data::CounterfactualData; dim_red=:pca)

# Set up:
X, y, _ = prepare_for_plotting(data; dim_red = dim_red)
X, y, _ = prepare_for_plotting(data; dim_red=dim_red)

# Scatter plot:
for (i, x) in enumerate(unique(sort(y)))
Expand Down
Loading

2 comments on commit b32b5dc

@pat-alt
Copy link
Member Author

@pat-alt pat-alt commented on b32b5dc Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Moved from Plots.jl method overloading to RecipesBase to avoid type piracy.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/114589

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.3.0 -m "<description of version>" b32b5dc602e839a595dc1d87fb55d2440fe6a1a7
git push origin v1.3.0

Please sign in to comment.