-
Notifications
You must be signed in to change notification settings - Fork 7
/
measures.jl
38 lines (31 loc) · 1.46 KB
/
measures.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
using Statistics: Statistics
include("faithfulness/faithfulness.jl")
include("plausibility/plausibility.jl")
"""
validity(ce::CounterfactualExplanation; γ=0.5)
Checks of the counterfactual search has been successful with respect to the probability threshold `γ`. In case multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
"""
function validity(ce::CounterfactualExplanation; agg=Statistics.mean, γ=0.5)
val = agg(CounterfactualExplanations.target_probs(ce) .>= γ)
val = val isa LinearAlgebra.AbstractMatrix ? vec(val) : val
return val
end
"""
validity_strict(ce::CounterfactualExplanation)
Checks if the counterfactual search has been strictly valid in the sense that it has converged with respect to the pre-specified target probability `γ`.
"""
function validity_strict(ce::CounterfactualExplanation)
return validity(ce; γ=ce.convergence.decision_threshold)
end
"""
redundancy(ce::CounterfactualExplanation)
Computes the feature redundancy: that is, the number of features that remain unchanged from their original, factual values.
"""
function redundancy(ce::CounterfactualExplanation; agg=Statistics.mean, tol=1e-5)
x′ = CounterfactualExplanations.counterfactual(ce)
redundant_x = [
agg(sum(abs.(x .- ce.x) .< tol) / size(x, 1)) for x in eachslice(x′; dims=ndims(x′))
]
redundant_x = length(redundant_x) == 1 ? redundant_x[1] : redundant_x
return redundant_x
end