Skip to content

Commit

Permalink
plots for CE can now also plot loss contour
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 11, 2024
1 parent cee2185 commit b08a87d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TaijaPlotting"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
authors = ["Patrick Altmeyer"]
version = "1.0.10"
version = "1.1.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -21,12 +21,12 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[compat]
CategoricalArrays = "0.10"
ConformalPrediction = "0.1"
ConformalPrediction = "0.1, 1"
CounterfactualExplanations = "0.1, 1"
DataAPI = "1"
Distributions = "0.25"
Flux = "0.12, 0.13, 0.14"
LaplaceRedux = "0.1, 0.2"
LaplaceRedux = "0.1, 0.2, 1"
LinearAlgebra = "1.7, 1.8, 1.9"
MLJBase = "0.21, 0.22, 1"
MLUtils = "0.4"
Expand Down
56 changes: 34 additions & 22 deletions src/CounterfactualExplations/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ function Plots.plot(
alpha = 1.0,
contour_alpha = 1.0,
dim_red::Symbol = :pca,
plot_loss::Bool = false,
loss_fun::Union{Nothing,Function} = nothing,
kwargs...,
)
X, _ = DataPreprocessing.unpack_data(data)
Expand All @@ -26,6 +28,13 @@ function Plots.plot(
= vec(ŷ)
end

# Target:
if isnothing(target)
target = data.y_levels[1]
@info "No target label supplied, using first."
end
target_encoded = data.output_encoder(target)

X, y, multi_dim = prepare_for_plotting(data; dim_red = dim_red)

# Surface range:
Expand All @@ -43,40 +52,43 @@ function Plots.plot(
x_range = convert.(eltype(X), range(xlims[1]; stop = xlims[2], length = length_out))
y_range = convert.(eltype(X), range(ylims[1]; stop = ylims[2], length = length_out))

if multi_dim
knn1, y_train = voronoi(X, ŷ)
predict_ =
(X::AbstractVector) -> vec(
pdf(
MLJBase.predict(knn1, MLJBase.table(reshape(X, 1, 2))),
DataAPI.levels(y_train),
),
)
Z = [predict_([x, y]) for x in x_range, y in y_range]
plot_loss = plot_loss || !isnothing(loss_fun)

if plot_loss
# Loss surface:
Z = [loss_fun(logits(M, [x, y][:, :]), target_encoded) for x in x_range, y in y_range]
else
predict_ = function (X::AbstractVector)
X = permutedims(permutedims(X))
z = predict_proba(M, data, X)
return z
# Prediction surface:
if multi_dim
knn1, y_train = voronoi(X, ŷ)
predict_ =
(X::AbstractVector) -> vec(
pdf(
MLJBase.predict(knn1, MLJBase.table(reshape(X, 1, 2))),
DataAPI.levels(y_train),
),
)
Z = [predict_([x, y]) for x in x_range, y in y_range]
else
predict_ = function (X::AbstractVector)
X = permutedims(permutedims(X))
z = predict_proba(M, data, X)
return z
end
Z = [predict_([x, y]) for x in x_range, y in y_range]
end
Z = [predict_([x, y]) for x in x_range, y in y_range]
end

# Pre-processes:
Z = reduce(hcat, Z)
if isnothing(target)
target = data.y_levels[1]
if size(Z, 1) > 2
@info "No target label supplied, using first."
end
end
target_idx = get_target_index(data.y_levels, target)
z = plot_loss ? Z[1, :] : Z[target_idx, :]

# Contour:
Plots.contourf(
x_range,
y_range,
Z[Int(target_idx), :];
z;
colorbar = colorbar,
title = title,
linewidth = linewidth,
Expand Down

2 comments on commit b08a87d

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Adds support for plotting the loss landscape for the counterfactual search in the two-dimensional (latent) space.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/104687

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" b08a87d5fb31dd36f68ae042f909f55c14c024ac
git push origin v1.1.0

Please sign in to comment.