-
Notifications
You must be signed in to change notification settings - Fork 7
/
invalidation_rate.jl
83 lines (70 loc) · 2.69 KB
/
invalidation_rate.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
using Distributions: Distributions
using Flux: Flux
using LinearAlgebra: LinearAlgebra
"""
    InvalidationRateConvergence(; invalidation_rate=0.1, max_iter=100, variance=0.01)

Convergence criterion for counterfactual search based on the estimated invalidation
rate of the counterfactual (see `invalidation_rate`).

# Fields
- `invalidation_rate::AbstractFloat = 0.1`: threshold. `converged` requires the
  estimated invalidation rate to be strictly below this value, and `hinge_loss`
  penalises the amount by which the estimate exceeds it.
- `max_iter::Int = 100`: maximum number of iterations. NOTE(review): not read anywhere
  in this file; presumably consumed by the search loop elsewhere — confirm.
- `variance::AbstractFloat = 0.01`: the σ² that scales the identity covariance in the
  invalidation-rate approximation; `invalidation_rate` reads it as
  `ce.convergence.variance`.
"""
Base.@kwdef struct InvalidationRateConvergence <: AbstractConvergence
    # Threshold on the estimated invalidation rate (see docstring).
    invalidation_rate::AbstractFloat = 0.1
    # Iteration cap; not used in this file.
    max_iter::Int = 100
    # σ² of the Gaussian perturbation model used by `invalidation_rate`.
    variance::AbstractFloat = 0.01
end
"""
    converged(
        convergence::InvalidationRateConvergence,
        ce::AbstractCounterfactualExplanation,
        x::Union{AbstractArray,Nothing}=nothing,
    )

Check whether the counterfactual search has converged under the invalidation-rate
criterion.

The search counts as converged when the model's predicted label for the current
counterfactual `ce.x′` equals `ce.target` **and** the estimated invalidation rate
(see `invalidation_rate`) is strictly below `convergence.invalidation_rate`.

# Arguments
- `convergence::InvalidationRateConvergence`: supplies the invalidation-rate threshold.
- `ce::AbstractCounterfactualExplanation`: the counterfactual explanation being searched.
- `x::Union{AbstractArray,Nothing}=nothing`: not used by this method.

# Returns
`true` if both conditions hold, `false` otherwise.

Note: the invalidation rate itself is computed with `ce.convergence.variance`, not with
`convergence.variance`.
"""
function converged(
    convergence::InvalidationRateConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
    label = Models.predict_label(ce.M, ce.data, ce.x′)[1]
    # Check the cheap label condition first: `invalidation_rate` differentiates the
    # model, so there is no point computing it while the counterfactual is not valid.
    label == ce.target || return false
    return invalidation_rate(ce) < convergence.invalidation_rate
end
"""
    invalidation_rate(ce::AbstractCounterfactualExplanation)

Estimate the invalidation rate of a counterfactual explanation using the first-order
approximation from PROBE (Pawelczyk et al., 2023):

    IR ≈ 1 - Φ(f / sqrt(∇fᵀ (σ² I) ∇f)) = 1 - Φ(f / sqrt(σ² ‖∇f‖²))

Here:
- `f` is the logit of `ce.target` at the decoded counterfactual.
- `∇f` is the gradient of `f` with respect to the counterfactual state `ce.s′`.
- `σ² = ce.convergence.variance`.
- `Φ` is the standard normal CDF.

# Arguments
- `ce::AbstractCounterfactualExplanation`: the counterfactual explanation to calculate
  the invalidation rate for. `ce.convergence` must have a `variance` field.

# Returns
The estimated invalidation rate, `1 - Φ(...)`. It lies in `[0, 1]`, or is `NaN` when
both the logit and its gradient are zero.

# Throws
- `ArgumentError` if `ce.target` is not one of `ce.data.y_levels`.
"""
function invalidation_rate(ce::AbstractCounterfactualExplanation)
    index_target = findfirst(==(ce.target), ce.data.y_levels)
    isnothing(index_target) && throw(
        ArgumentError("target $(ce.target) is not one of the data's y_levels"),
    )

    # Target-class logit of the decoded counterfactual. Decoding happens inside the
    # closure so the gradient below flows back to the counterfactual state `ce.s′`.
    target_logit() = logits(ce.M, CounterfactualExplanations.decode_state(ce))[index_target]

    f = target_logit()
    ∇f = Flux.gradient(target_logit, Flux.params(ce.s′))[ce.s′]

    # With Σ = σ² I the quadratic form ∇fᵀ Σ ∇f reduces to σ² ‖∇f‖², so no identity
    # matrix is needed. A `nothing` gradient means the logit does not depend on `ce.s′`.
    squared_norm = isnothing(∇f) ? zero(f) : sum(abs2, ∇f)
    denominator = sqrt(ce.convergence.variance * squared_norm)

    ϕ = Distributions.cdf(Distributions.Normal(0, 1), f / denominator)
    return 1 - ϕ
end
"""
    hinge_loss(convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation)

Hinge penalty on the invalidation rate of `ce`. It is the amount by which
`invalidation_rate(ce)` exceeds the threshold `convergence.invalidation_rate`,
clamped from below at zero.

# Arguments
- `convergence::InvalidationRateConvergence`: provides the threshold
  `convergence.invalidation_rate`.
- `ce::AbstractCounterfactualExplanation`: the counterfactual explanation to penalise.

# Returns
`max(0, invalidation_rate(ce) - convergence.invalidation_rate)`. This is zero whenever
the estimated rate is at or below the threshold.
"""
function hinge_loss(
    convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation
)
    threshold = convergence.invalidation_rate
    excess = invalidation_rate(ce) - threshold
    # `max` (rather than a comparison) keeps NaN propagation and type promotion intact.
    return max(zero(excess), excess)
end